Exemplo n.º 1
0
def compute_splices(bbdb,
                    bbpairs,
                    verbosity,
                    parallel,
                    pbar,
                    pbar_interval=10.0,
                    **kw):
    """Run ``_valid_splice_pairs`` for every pair of bblock names in ``bbpairs``.

    Args:
        bbdb: bblock database; ``bbdb.bblock(name)`` yields the bblock that is
            wrapped in ``BBlockWrap`` before submission.
        bbpairs: sequence of ``(name0, name1)`` pairs; must support ``.copy()``
            and its items must be hashable (they are used as dict keys).
        verbosity: accepted for interface compatibility; not used here.
        parallel: worker-process count; falsy runs everything in-process.
        pbar: if true, wrap completion in a tqdm progress bar.
        pbar_interval: minimum seconds between progress-bar refreshes.
        **kw: forwarded to ``_valid_splice_pairs``.

    Returns:
        dict mapping each pair in ``bbpairs`` to its ``_valid_splice_pairs``
        result, iterated in the original ``bbpairs`` order.
    """
    shuffled_pairs = bbpairs.copy()
    # submission order is randomized (presumably to spread costly pairs
    # across workers); the returned dict is rebuilt in the caller's order
    shuffle(shuffled_pairs)
    executor = (cf.ProcessPoolExecutor(max_workers=parallel)
                if parallel else InProcessExecutor())
    with executor as pool:
        pending = []
        for pair in shuffled_pairs:
            first = BBlockWrap(bbdb.bblock(pair[0]))
            second = BBlockWrap(bbdb.bblock(pair[1]))
            fut = pool.submit(_valid_splice_pairs, first, second, **kw)
            # tag the future so its result can be matched after as_completed
            fut.bbpair = pair
            pending.append(fut)
        print("batch compute_splices, npairs:", len(pending))
        completed = cf.as_completed(pending)
        if pbar:
            completed = tqdm(completed,
                             "precache splices",
                             mininterval=pbar_interval,
                             total=len(pending))
        by_pair = dict()
        for fut in completed:
            by_pair[fut.bbpair] = fut.result()
    return {pair: by_pair[pair] for pair in bbpairs}
Exemplo n.º 2
0
def get_allowed_splices(
        u,
        ublks,
        v,
        vblks,
        splicedb=None,
        splice_max_rms=0.7,
        splice_ncontact_cut=30,
        splice_clash_d2=4.0**2,  # ca only
        splice_contact_d2=8.0**2,
        splice_rms_range=6,
        splice_clash_contact_range=60,
        splice_clash_contact_by_helix=True,
        splice_ncontact_no_helix_cut=0,
        splice_nhelix_contacted_cut=0,
        splice_max_chain_length=999999,
        skip_on_fail=True,
        parallel=False,
        verbosity=1,
        cache_sync=0.001,
        precache_splices=False,
        pbar=False,
        pbar_interval=10.0,
        **kw):
    """Find which v-side splice positions each u-side splice position may join.

    For every (u bblock, v bblock) pair, splice metrics are either read from
    ``splicedb`` or computed with ``_jit_splice_metrics`` (optionally in a
    process pool). Newly computed metrics are filtered by the ``splice_*``
    cutoffs, converted with ``_splice_respairs`` and stored back into
    ``splicedb`` when one is given.

    Args:
        u, v: vertices on either side of the edge; ``u.dirn[1] + v.dirn[0]``
            must equal 1.
        ublks, vblks: bblocks indexed by ``u.ibblock`` / ``v.ibblock`` values.
        splicedb: optional splice cache providing ``partial``, ``add`` and
            ``sync_to_disk``.
        splice_*: acceptance cutoffs / metric parameters for a splice.
        skip_on_fail: forwarded to ``_jit_splice_metrics``.
        parallel: worker-process count; falsy computes in-process.
        verbosity: values > 0 enable diagnostic printing.
        cache_sync: probability of syncing ``splicedb`` after each newly added
            result; any value > 0 also syncs once at the end.
        precache_splices: when true the progress bar is suppressed.
        pbar, pbar_interval: tqdm toggle and minimum refresh interval.
        **kw: ignored.

    Returns:
        ``(valid_splices, nout, nent)`` where ``nout`` is the number of u-side
        splice positions (from ``u.inout[:, 1]``), ``nent`` the number of
        v-side positions (from ``v.inbreaks``), and ``valid_splices[i]`` lists
        the v-side position indices allowed for u-side position ``i``.
    """
    assert (u.dirn[1] + v.dirn[0]) == 1, 'get_allowed_splices dirn mismatch'

    # note: this is duplicated in edge_batch.py and they need to be the same
    params = (splice_max_rms, splice_ncontact_cut, splice_clash_d2,
              splice_contact_d2, splice_rms_range, splice_clash_contact_range,
              splice_clash_contact_by_helix, splice_ncontact_no_helix_cut,
              splice_nhelix_contacted_cut, splice_max_chain_length)

    outidx = _get_outidx(u.inout[:, 1])
    outblk = u.ibblock[outidx]
    outres = u.ires[outidx, 1]

    inblk = v.ibblock[v.inbreaks[:-1]]
    inres = v.ires[v.inbreaks[:-1], 0]
    inblk_breaks = contig_idx_breaks(inblk)

    # group residue numbers by bblock index on each side
    outblk_res = defaultdict(list)
    for iblk, ires in zip(outblk, outres):
        outblk_res[iblk].append(ires)
    for iblk in outblk_res.keys():
        outblk_res[iblk] = np.array(outblk_res[iblk], 'i4')

    inblk_res = defaultdict(list)
    for iblk, ires in zip(inblk, inres):
        inblk_res[iblk].append(ires)
    for iblk in inblk_res.keys():
        inblk_res[iblk] = np.array(inblk_res[iblk], 'i4')
        assert np.all(sorted(inblk_res[iblk]) == inblk_res[iblk])

    nout = sum(len(a) for a in outblk_res.values())
    nent = sum(len(a) for a in inblk_res.values())
    valid_splices = [list() for i in range(nout)]

    swapped = False
    if u.dirn[1] == 0:  # swap so N-to-C!
        swapped = True
        u, ublks, v, vblks = v, vblks, u, ublks
        outblk_res, inblk_res = inblk_res, outblk_res
        outblk, inblk = inblk, outblk

    pairs_with_no_valid_splices = 0
    tcache = 0

    exe = InProcessExecutor()
    if parallel:
        exe = cf.ProcessPoolExecutor(max_workers=parallel)
    with exe as pool:
        futures = list()
        ofst0 = 0
        for iblk0, ires0 in outblk_res.items():
            blk0 = ublks[iblk0]
            key0 = blk0.filehash
            t = time()
            cache = splicedb.partial(params, key0) if splicedb else None
            tcache += time() - t
            ofst1 = 0
            for iblk1, ires1 in inblk_res.items():
                blk1 = vblks[iblk1]
                key1 = blk1.filehash
                if cache and key1 in cache and cache[key1]:
                    # cached residue pairs; wrapped so they flow through the
                    # same completion loop as computed futures
                    splices = cache[key1]
                    future = NonFuture(splices, dummy=True)
                else:
                    future = pool.submit(
                        _jit_splice_metrics, blk0.chains, blk1.chains,
                        blk0.ncac, blk1.ncac, blk0.stubs, blk1.stubs,
                        blk0.connections, blk1.connections, blk0.ss, blk1.ss,
                        blk0.cb, blk1.cb, splice_clash_d2, splice_contact_d2,
                        splice_rms_range, splice_clash_contact_range,
                        splice_clash_contact_by_helix, splice_max_rms,
                        splice_max_chain_length, skip_on_fail)
                # remember where this pair's residues sit in the flat indexing
                future.stash = (iblk0, iblk1, ofst0, ofst1, ires0, ires1)
                futures.append(future)
                ofst1 += len(ires1)
            ofst0 += len(ires0)

        if verbosity > 0 and tcache > 1.0:
            print('get_allowed_splices read caches time:', tcache)

        future_iter = cf.as_completed(futures)
        if pbar and not precache_splices:
            future_iter = tqdm(future_iter,
                               'checking splices',
                               mininterval=pbar_interval,
                               total=len(futures))
        for future in future_iter:
            iblk0, iblk1, ofst0, ofst1, ires0, ires1 = future.stash
            result = future.result()
            # a 5-tuple of arrays is a freshly computed metric set; cached
            # entries are already residue pairs. (`==`, not `is`: identity
            # comparison with an int literal is unreliable.)
            if len(result) == 5 and isinstance(result[0], np.ndarray):
                rms, nclash, ncontact, ncnh, nhc = result
                ok = ((nclash == 0) * (rms <= splice_max_rms) *
                      (ncontact >= splice_ncontact_cut) *
                      (ncnh >= splice_ncontact_no_helix_cut) *
                      (nhc >= splice_nhelix_contacted_cut))
                result = _splice_respairs(ok, ublks[iblk0], vblks[iblk1])
                if verbosity > 0 and np.sum(ok) == 0:
                    print('N no clash', np.sum(nclash == 0))
                    print('N rms', np.sum(rms <= splice_max_rms))
                    print('N contact', np.sum(ncontact >= splice_ncontact_cut))

                if splicedb:
                    key0 = ublks[iblk0].filehash  # C-term side
                    key1 = vblks[iblk1].filehash  # N-term side
                    splicedb.add(params, key0, key1, result)
                    # occasional random sync bounds data loss on interruption
                    if np.random.random() < cache_sync:
                        print('sync_to_disk splices data')
                        splicedb.sync_to_disk()

            if swapped:
                # undo the N-to-C swap so indices refer to the caller's u/v
                result = result[1], result[0]
                ires0, ires1 = ires1, ires0
                ofst0, ofst1 = ofst1, ofst0

            if len(result[0]) == 0:
                pairs_with_no_valid_splices += 1
                continue
            index_of_ires0 = _index_of_map(ires0, np.max(result[0]))
            index_of_ires1 = _index_of_map(ires1, np.max(result[1]))
            irs = index_of_ires0[result[0]]
            jrs = index_of_ires1[result[1]]
            # keep only pairs whose residues are both splice positions here
            ok = (irs >= 0) * (jrs >= 0)
            irs = irs[ok] + ofst0
            jrs = jrs[ok] + ofst1
            for ir, jr in zip(irs, jrs):
                valid_splices[ir].append(jr)

    if cache_sync > 0 and splicedb:
        splicedb.sync_to_disk()

    if pairs_with_no_valid_splices:
        print('pairs with no valid splices: ', pairs_with_no_valid_splices,
              'of',
              len(outblk_res) * len(inblk_res))

    return valid_splices, nout, nent
Exemplo n.º 3
0
def grow_linear(ssdag,
                loss_function=null_lossfunc,
                loss_threshold=2.0,
                last_bb_same_as=-1,
                parallel=0,
                monte_carlo=0,
                verbosity=0,
                merge_bblock=None,
                lbl='',
                pbar=False,
                pbar_interval=10.0,
                no_duplicate_bases=True,
                max_linear=1000000,
                **kw):
    """Grow linear paths through ``ssdag`` and merge the per-job results.

    Work is submitted to a thread pool of ``parallel`` workers (or run
    in-process when ``parallel`` is falsy). With ``monte_carlo`` set, each job
    runs ``_grow_linear_mc_start`` for ``monte_carlo`` seconds; otherwise each
    ``_grow_linear_start`` job covers a slice of the first vertex's range.

    Args:
        ssdag: search dag exposing ``verts``, ``edges`` and ``bbs``.
        loss_function, loss_threshold, last_bb_same_as, max_linear:
            forwarded to every job (``None`` for ``last_bb_same_as`` means -1).
        parallel: worker-thread count; 0/falsy runs in-process.
        monte_carlo: seconds of Monte-Carlo search per job; 0 = exhaustive.
        verbosity: forwarded to Monte-Carlo jobs.
        merge_bblock: forwarded to Monte-Carlo jobs; also selects the
            progress-bar row for the exhaustive search.
        lbl: label forwarded to Monte-Carlo jobs and shown in the progress bar.
        pbar, pbar_interval: progress-bar toggle and refresh interval.
        no_duplicate_bases: if true pass each bblock's ``basehash`` to jobs,
            otherwise zeros (presumably disabling a duplicate-base check).
        **kw: ignored.

    Returns:
        ``ResultJIT`` with duplicates removed, sorted by ascending ``err``.
    """
    verts = ssdag.verts
    edges = ssdag.edges
    if last_bb_same_as is None:
        last_bb_same_as = -1
    assert len(verts) > 1
    assert len(verts) == len(edges) + 1
    assert verts[0].dirn[0] == 2
    assert verts[-1].dirn[1] == 2
    for ivertex in range(len(verts) - 1):
        assert verts[ivertex].dirn[1] + verts[ivertex + 1].dirn[0] == 1

    exe = cf.ThreadPoolExecutor(
        max_workers=parallel) if parallel else InProcessExecutor()
    with exe as pool:
        bb_base = tuple([
            np.array([b.basehash if no_duplicate_bases else 0 for b in bb],
                     dtype=np.int64) for bb in ssdag.bbs
        ])
        verts_pickleable = [v._state for v in verts]
        edges_pickleable = [e._state for e in edges]
        kwargs = dict(
            bb_base=bb_base,
            verts_pickleable=verts_pickleable,
            edges_pickleable=edges_pickleable,
            loss_function=loss_function,
            loss_threshold=loss_threshold,
            last_bb_same_as=last_bb_same_as,
            nresults=0,
            isplice=0,
            splice_position=np.eye(4, dtype=vertex_xform_dtype),
            max_linear=max_linear,
        )
        # The job function is passed positionally: Executor.submit's `fn` is
        # positional-only since Python 3.9, so `submit(fn=...)` raises there.
        futures = list()
        if monte_carlo:
            fn = _grow_linear_mc_start
            kwargs['seconds'] = monte_carlo
            kwargs['ivertex_range'] = (0, verts[0].len)
            kwargs['merge_bblock'] = merge_bblock
            kwargs['lbl'] = lbl
            kwargs['verbosity'] = verbosity
            kwargs['pbar'] = pbar
            kwargs['pbar_interval'] = pbar_interval
            njob = cpu_count() if parallel else 1
            for ivert in range(njob):
                kwargs['threadno'] = ivert
                # `**kwargs` copies the dict, so later mutation is safe
                futures.append(pool.submit(fn, **kwargs))
        else:
            fn = _grow_linear_start
            nbatch = max(1, int(verts[0].len / 64 / cpu_count()))
            for ivert in range(0, verts[0].len, nbatch):
                ivert_end = min(verts[0].len, ivert + nbatch)
                kwargs['ivertex_range'] = ivert, ivert_end
                futures.append(pool.submit(fn, **kwargs))
        results = list()
        if monte_carlo:
            for f in cf.as_completed(futures):
                results.append(f.result())
        else:
            desc = 'linear search ' + str(lbl)
            if merge_bblock is None:
                merge_bblock = 0
            fiter = cf.as_completed(futures)
            if pbar:
                fiter = tqdm(fiter,
                             desc=desc,
                             position=merge_bblock + 1,
                             mininterval=pbar_interval,
                             total=len(futures))
            for f in fiter:
                results.append(f.result())
    tot_stats = zero_search_stats()
    for i in range(len(tot_stats)):
        tot_stats[i][0] += sum([r.stats[i][0] for r in results])
    result = ResultJIT(pos=np.concatenate([r.pos for r in results]),
                       idx=np.concatenate([r.idx for r in results]),
                       err=np.concatenate([r.err for r in results]),
                       stats=tot_stats)
    result = remove_duplicate_results(result)
    order = np.argsort(result.err)
    return ResultJIT(pos=result.pos[order],
                     idx=result.idx[order],
                     err=result.err[order],
                     stats=result.stats)
Exemplo n.º 4
0
def simple_search_dag(criteria,
                      db=None,
                      nbblocks=100,
                      min_seg_len=15,
                      parallel=False,
                      verbosity=0,
                      timing=0,
                      modbbs=None,
                      make_edges=True,
                      merge_bblock=None,
                      precache_splices=False,
                      precache_only=False,
                      bbs=None,
                      only_seg=None,
                      source=None,
                      print_edge_summary=False,
                      no_duplicate_bases=False,
                      shuffle_bblocks=False,
                      use_saved_bblocks=False,
                      output_prefix='./worms',
                      **kw):
    """Build the search-space dag (bblocks, vertices, edges) for ``criteria``.

    Bblocks per segment come from ``bbs`` if given, else from a saved pickle
    (``use_saved_bblocks``), else from ``bbdb.query``. Splices may then be
    precomputed, a ``Vertex`` is built per segment (reusing matching vertices
    and edges from ``source``), and an ``Edge`` joins each consecutive pair.

    Args:
        criteria: provides ``bbspec`` (query, direction) pairs,
            ``which_mergeseg()``, ``is_cyclic`` and ``from_seg``/``to_seg``.
        db: ``(bbdb, spdb)`` — bblock database and splice database.
        nbblocks: max bblocks requested per query.
        min_seg_len: forwarded to ``Vertex``.
        parallel: worker count for queries, splice precache and vertex build.
        verbosity: forwarded to splice precache and ``Edge``.
        timing: if truthy, also return the bblock/vertex/edge timings.
        modbbs: optional callable that may modify the bblock lists in place.
        make_edges: build edges; otherwise ``edges`` is empty.
        merge_bblock: if >= 0, merge segments keep only this bblock.
        precache_splices, precache_only: precompute splices; optionally stop
            there and return the bblock lists.
        bbs: explicit per-segment bblock lists (copied, not mutated).
        only_seg: build only this segment's vertex.
        source: earlier dag whose matching vertices/edges are reused.
        print_edge_summary: print ``_print_edge_summary(edges)``.
        no_duplicate_bases: accepted but unused here.
        shuffle_bblocks: forwarded to ``bbdb.query``.
        use_saved_bblocks, output_prefix: load/save bblock names from
            ``output_prefix + '_bblocks.pickle'``.
        **kw: forwarded to ``precompute_splicedb`` and ``Edge``.

    Returns:
        ``SearchSpaceDag``; with ``timing`` the tuple
        ``(dag, tdb, tvertex, tedge)``; with ``precache_only`` just the bblock
        lists.
    """
    bbdb, spdb = db
    queries, directions = zip(*criteria.bbspec)
    tdb = time()
    if bbs is None:
        bbs = list()
        savename = output_prefix + '_bblocks.pickle'

        if use_saved_bblocks and os.path.exists(savename):
            # NOTE(review): unpickling trusts the file contents; only point
            # this at files this program wrote itself.
            with open(savename, 'rb') as inp:
                bbnames_list = _pickle.load(inp)
            for bbnames in bbnames_list:
                bbs.append([bbdb.bblock(n) for n in bbnames])

        else:
            # segments sharing one bblock list: the first is queried and the
            # rest reuse it. Loop-invariant, so computed once.
            msegs = [
                i + len(queries) if i < 0 else i
                for i in criteria.which_mergeseg()
            ]
            for iquery, query in enumerate(queries):
                if iquery in msegs[1:]:
                    print('seg', iquery, 'repeating bblocks from', msegs[0])
                    bbs.append(bbs[msegs[0]])
                    continue
                bbs0 = bbdb.query(
                    query,
                    max_bblocks=nbblocks,
                    shuffle_bblocks=shuffle_bblocks,
                    parallel=parallel,
                )
                bbs.append(bbs0)

        bases = [
            Counter(bytes(b.base).decode('utf-8') for b in bbs0)
            for bbs0 in bbs
        ]
        assert len(bbs) == len(queries)
        for i, v in enumerate(bbs):
            assert len(v) > 0, 'no bblocks for query: "' + queries[i] + '"'
        print('bblock queries:', str(queries))
        print('bblock numbers:', [len(b) for b in bbs])
        print('bblocks id:', [id(b) for b in bbs])
        print('bblock0 id ', [id(b[0]) for b in bbs])
        print('base_counts:')
        for query, basecount in zip(queries, bases):
            counts = ' '.join(f'{k}: {c}' for k, c in basecount.items())
            print(f'   {query:10}', counts)

        if criteria.is_cyclic:
            for a, b in zip(bbs[criteria.from_seg], bbs[criteria.to_seg]):
                assert a is b
            bbs[criteria.to_seg] = bbs[criteria.from_seg]

        # only write the name list if it was not just loaded from disk
        if use_saved_bblocks and not os.path.exists(savename):
            bbnames = [[bytes(b.file).decode('utf-8') for b in bb]
                       for bb in bbs]
            with open(savename, 'wb') as out:
                _pickle.dump(bbnames, out)

    else:
        bbs = bbs.copy()

    assert len(bbs) == len(criteria.bbspec)
    if modbbs:
        modbbs(bbs)
    if merge_bblock is not None and merge_bblock >= 0:
        for i in criteria.which_mergeseg():
            bbs[i] = (bbs[i][merge_bblock], )

    tdb = time() - tdb

    if precache_splices:
        bbnames = [[bytes(bb.file) for bb in bbtup] for bbtup in bbs]
        bbpairs = set()
        for i in range(len(bbnames) - 1):
            bb1 = bbnames[i]
            bb2 = bbnames[i + 1]
            dirn1 = directions[i]
            # pairs are keyed C-term side first, so reverse N-facing exits
            rev = dirn1[1] == 'N'
            if bbs[i] is bbs[i + 1]:
                bbpairs.update((a, a) for a in bb1)
            else:
                bbpairs.update(
                    (b, a) if rev else (a, b) for a in bb1 for b in bb2)
        precompute_splicedb(db,
                            bbpairs,
                            verbosity=verbosity,
                            parallel=parallel,
                            **kw)
    if precache_only:
        return bbs

    verts = [None] * len(queries)
    edges = [None] * len(queries[1:])
    if source:
        srcdirn = [''.join('NC_' [d] for d in source.verts[i].dirn)
                   for i in range(len(source.verts))] # yapf: disable
        srcverts, srcedges = list(), list()
        for i, bb in enumerate(bbs):
            for isrc, bbsrc in enumerate(source.bbs):
                if directions[i] != srcdirn[isrc]: continue
                if [b.filehash for b in bb] == [b.filehash for b in bbsrc]:
                    verts[i] = source.verts[isrc]
                    srcverts.append(isrc)
        for i, bb in enumerate(zip(bbs, bbs[1:])):
            bb0, bb1 = bb
            for isrc, bbsrc in enumerate(zip(source.bbs, source.bbs[1:])):
                bbsrc0, bbsrc1 = bbsrc
                if directions[i] != srcdirn[isrc]: continue
                if directions[i + 1] != srcdirn[isrc + 1]: continue
                he = [b.filehash for b in bb0] == [b.filehash for b in bbsrc0]
                he &= [b.filehash for b in bb1] == [b.filehash for b in bbsrc1]
                if not he: continue
                edges[i] = source.edges[isrc]
                srcedges.append(isrc)

    if not make_edges:
        edges = []

    tvertex = time()
    exe = InProcessExecutor()

    if parallel:
        exe = cf.ThreadPoolExecutor(max_workers=parallel)
    with exe as pool:
        if only_seg is not None:
            save = bbs, directions
            bbs = [bbs[only_seg]]
            directions = [directions[only_seg]]
            verts = [verts[only_seg]]
        futures = list()
        for i, bb in enumerate(bbs):
            dirn = directions[i]
            if verts[i] is None:
                futures.append(
                    pool.submit(Vertex, bb, dirn, min_seg_len=min_seg_len))
        verts_new = [f.result() for f in futures]
        isnone = [i for i in range(len(verts)) if verts[i] is None]
        for inew, inone in enumerate(isnone):
            verts[inone] = verts_new[inew]
        # every missing vertex must have been built exactly once
        assert len(isnone) == len(verts_new)
        assert all(v for v in verts)
        if only_seg is not None:
            verts = ([None] * only_seg + verts + [None] *
                     (len(queries) - only_seg - 1))
            bbs, directions = save
    tvertex = time() - tvertex

    # defined even without edges so the `timing` tuple below never fails
    tedge = 0.0
    if make_edges:
        tedge = time()
        for i, e in enumerate(edges):
            if e is None:
                edges[i] = Edge(verts[i],
                                bbs[i],
                                verts[i + 1],
                                bbs[i + 1],
                                splicedb=spdb,
                                verbosity=verbosity,
                                precache_splices=precache_splices,
                                **kw)
        tedge = time() - tedge
        if print_edge_summary:
            _print_edge_summary(edges)
        # the splice db is optional (Edge accepts splicedb=None)
        if spdb is not None:
            spdb.sync_to_disk()

    toret = SearchSpaceDag(criteria.bbspec, bbs, verts, edges)
    if timing:
        toret = toret, tdb, tvertex, tedge
    return toret
Exemplo n.º 5
0
def _nbblocks_per_query(nbblocks, nquery):
    """Return one max-bblocks count per query from a scalar or a sequence.

    A scalar or a length-1 sequence is broadcast to all ``nquery`` queries;
    any other sequence is returned (as a list) unchanged.
    """
    if np.ndim(nbblocks) == 0:
        return [nbblocks] * nquery
    nbblocks = list(nbblocks)
    if len(nbblocks) == 1:
        return nbblocks * nquery
    return nbblocks


def _report_edge_splice_failures(edge, bbs0, bbs1, edge_analysis):
    """Print per-pair and summary diagnostics for failing bblock pairs of an edge.

    ``edge_analysis`` entries are tuples laid out as
    ``(iblk0, iblk1, ofst0, ofst1, ires0, ires1, ok, f_clash, f_rms,
    f_ncontact, f_ncnh, f_nhc, m_rms, m_ncontact, m_ncnh, m_nhc)``, where the
    ``f_*`` values are fractions passing each cut and ``m_*`` are extrema.
    ``bbs0``/``bbs1`` are the bblock lists on either side of the edge.
    """
    print("=" * 80)
    print("info for edges with no valid splices",
          edge.total_allowed_splices())
    for tup in edge_analysis:
        iblk0, iblk1, ofst0, ofst1, ires0, ires1 = tup[:6]
        ok, f_clash, f_rms, f_ncontact, f_ncnh, f_nhc = tup[6:12]
        m_rms, m_ncontact, m_ncnh, m_nhc = tup[12:]
        if ok:
            continue
        assert len(bbs0) > iblk0
        assert len(bbs1) > iblk1
        print("=" * 80)
        print("edge Bblock A", bytes(bbs0[iblk0].file))
        print("edge Bblock B", bytes(bbs1[iblk1].file))
        print(
            f"bb {iblk0:3} {iblk1:3}",
            f"ofst {ofst0:4} {ofst1:4}",
            f"resi {ires0.shape} {ires1.shape}",
        )
        print(
            f"clash_ok {int(f_clash*100):3}%",
            f"rms_ok {int(f_rms*100):3}%",
            f"ncontact_ok {int(f_ncontact*100):3}%",
            f"ncnh_ok {int(f_ncnh*100):3}%",
            f"nhc_ok {int(f_nhc*100):3}%",
        )
        print(
            f"min_rms {m_rms:7.3f}",
            f"max_ncontact {m_ncontact:7.3f}",
            f"max_ncnh {m_ncnh:7.3f}",
            f"max_nhc {m_nhc:7.3f}",
        )
    print("=" * 80)
    fok = np.stack([x[7:12] for x in edge_analysis]).mean(axis=0)
    rmsmin = np.array([x[12] for x in edge_analysis]).min()
    fmx = np.stack([x[13:] for x in edge_analysis]).max(axis=0)
    print(f"{' SPLICE FAIL SUMMARY ':=^80}")
    print(f"splice clash ok               {int(fok[0]*100):3}%")
    print(f"splice rms ok                 {int(fok[1]*100):3}%")
    print(f"splice ncontacts ok           {int(fok[2]*100):3}%")
    print(f"splice ncontacts_no_helix ok  {int(fok[3]*100):3}%")
    print(f"splice nhelixcontacted ok     {int(fok[4]*100):3}%")
    print(f"min rms of any failing        {rmsmin}")
    print(
        f"max ncontact of any failing   {fmx[0]} (maybe large for non-5-helix splice)"
    )
    print(
        f"max ncontact_no_helix         {fmx[1]} (will be 999 for non-5-helix splice)"
    )
    print(
        f"max nhelix_contacted          {fmx[2]} (will be 999 for non-5-helix splice)"
    )
    print("=" * 80)


def simple_search_dag(
    criteria,
    db=None,
    nbblocks=(64, ),
    min_seg_len=15,
    parallel=False,
    verbosity=0,
    timing=0,
    modbbs=None,
    make_edges=True,
    merge_bblock=None,
    merge_segment=None,
    precache_splices=False,
    precache_only=False,
    bbs=None,
    bblock_ranges=(),
    only_seg=None,
    source=None,
    print_edge_summary=False,
    no_duplicate_bases=False,
    shuffle_bblocks=False,
    use_saved_bblocks=False,
    output_prefix="./worms",
    only_ivertex=(),
    **kw,
):
    """Build the search-space dag (bblocks, vertices, edges) for ``criteria``.

    Bblocks per segment come from ``bbs`` if given, else from a saved pickle
    (``use_saved_bblocks``), else from ``bbdb.query``. Splices may then be
    precomputed, a ``Vertex`` is built per segment (reusing matching vertices
    and edges from ``source``), and an ``Edge`` joins each consecutive pair;
    an edge with failing bblock pairs prints a diagnostic report and must
    still allow at least one splice.

    Args:
        criteria: provides ``bbspec`` (query, direction) pairs, ``is_cyclic``,
            ``from_seg``/``to_seg`` and optionally ``cloned_segments()``.
        db: ``(bbdb, spdb)`` — bblock database and splice database.
        nbblocks: max bblocks per query, either one value per query or a
            single int / length-1 sequence applied to every query.
        min_seg_len: forwarded to ``Vertex``.
        parallel: worker count for queries, splice precache and vertex build.
        verbosity: forwarded to splice precache and ``Edge``.
        timing: if truthy, also return the bblock/vertex/edge timings.
        modbbs: optional callable that may modify the bblock lists in place.
        make_edges: build edges; otherwise ``edges`` is empty.
        merge_bblock, merge_segment: if ``merge_bblock >= 0``, reduce the
            cloned segments (or ``merge_segment``, default 0) to that bblock.
        precache_splices, precache_only: precompute splices; optionally stop
            there and return the bblock lists.
        bbs: explicit per-segment bblock lists (copied, not mutated).
        bblock_ranges: flat ``(lb0, ub0, lb1, ub1, ...)`` slice bounds applied
            to freshly queried bblock lists.
        only_seg: build only this segment's vertex.
        source: earlier dag whose matching vertices/edges are reused.
        print_edge_summary: print ``_print_edge_summary(edges)``.
        no_duplicate_bases: accepted but unused here.
        shuffle_bblocks: forwarded to ``bbdb.query``.
        use_saved_bblocks, output_prefix: load/save bblock names from
            ``output_prefix + "_bblocks.pickle"``.
        only_ivertex: if it has one entry per vertex, reduce each vertex in
            place to that single index.
        **kw: forwarded to ``precompute_splicedb`` and ``Edge``.

    Returns:
        ``SearchSpaceDag``; with ``timing`` the tuple
        ``(dag, tdb, tvertex, tedge)``; with ``precache_only`` just the bblock
        lists.
    """
    bbdb, spdb = db
    queries, directions = zip(*criteria.bbspec)
    tdb = time()
    if bbs is None:
        nbblocks = _nbblocks_per_query(nbblocks, len(queries))
        bbs = list()
        savename = output_prefix + "_bblocks.pickle"

        if use_saved_bblocks and os.path.exists(savename):
            # NOTE(review): unpickling trusts the file contents; only point
            # this at files this program wrote itself.
            with open(savename, "rb") as inp:
                bbnames_list = _pickle.load(inp)
            for i, bbnames in enumerate(bbnames_list):
                bbs.append([bbdb.bblock(n) for n in bbnames[:nbblocks[i]]])

        else:
            # segments sharing one bblock list: the first is queried and the
            # rest reuse it. Loop-invariant, so computed once.
            msegs = []
            if hasattr(criteria, "cloned_segments"):
                msegs = [
                    i + len(queries) if i < 0 else i
                    for i in criteria.cloned_segments()
                ]
            for iquery, query in enumerate(queries):
                if iquery in msegs[1:]:
                    print("seg", iquery, "repeating bblocks from", msegs[0])
                    bbs.append(bbs[msegs[0]])
                    continue
                bbs0 = bbdb.query(
                    query,
                    max_bblocks=nbblocks[iquery],
                    shuffle_bblocks=shuffle_bblocks,
                    parallel=parallel,
                )
                bbs.append(bbs0)

            if bblock_ranges:
                assert len(bblock_ranges) == 2 * len(bbs)
                bbs_sliced = list()
                for ibb, bb in enumerate(bbs):
                    lb, ub = bblock_ranges[2 * ibb:2 * ibb + 2]
                    bbs_sliced.append(bb[lb:ub])
                bbs = bbs_sliced

            for ibb, bb in enumerate(bbs):
                print("bblocks", ibb)
                for b in bb:
                    print("   ", bytes(b.file).decode("utf-8"))

        bases = [
            Counter(bytes(b.base).decode("utf-8") for b in bbs0)
            for bbs0 in bbs
        ]
        assert len(bbs) == len(queries)
        for i, v in enumerate(bbs):
            assert len(v) > 0, 'no bblocks for query: "' + queries[i] + '"'
        print("bblock queries:", str(queries))
        print("bblock numbers:", [len(b) for b in bbs])
        print("bblocks id:", [id(b) for b in bbs])
        print("bblock0 id ", [id(b[0]) for b in bbs])
        print("base_counts:")
        for query, basecount in zip(queries, bases):
            counts = " ".join(f"{k}: {c}" for k, c in basecount.items())
            print(f"   {query:10}", counts)

        if criteria.is_cyclic:
            bbs[criteria.to_seg] = bbs[criteria.from_seg]

        if use_saved_bblocks and not os.path.exists(savename):
            bbnames = [[bytes(b.file).decode("utf-8") for b in bb]
                       for bb in bbs]
            with open(savename, "wb") as out:
                _pickle.dump(bbnames, out)

    else:
        bbs = bbs.copy()

    assert len(bbs) == len(criteria.bbspec)
    if modbbs:
        modbbs(bbs)

    if merge_bblock is not None and merge_bblock >= 0:
        if hasattr(criteria, "cloned_segments") and merge_segment is None:
            for i in criteria.cloned_segments():
                bbs[i] = (bbs[i][merge_bblock], )
        else:
            if merge_segment is None:
                merge_segment = 0
            bbs[merge_segment] = (bbs[merge_segment][merge_bblock], )

    tdb = time() - tdb

    if precache_splices:
        bbnames = [[bytes(bb.file) for bb in bbtup] for bbtup in bbs]
        bbpairs = set()
        for i in range(len(bbnames) - 1):
            bb1 = bbnames[i]
            bb2 = bbnames[i + 1]
            dirn1 = directions[i]
            # pairs are keyed C-term side first, so reverse N-facing exits
            rev = dirn1[1] == "N"
            if bbs[i] is bbs[i + 1]:
                bbpairs.update((a, a) for a in bb1)
            else:
                bbpairs.update(
                    (b, a) if rev else (a, b) for a in bb1 for b in bb2)
        precompute_splicedb(db,
                            bbpairs,
                            verbosity=verbosity,
                            parallel=parallel,
                            **kw)
    if precache_only:
        return bbs

    verts = [None] * len(queries)
    edges = [None] * len(queries[1:])
    if source:
        srcdirn = [
            "".join("NC_"[d] for d in source.verts[i].dirn)
            for i in range(len(source.verts))
        ]  # yapf: disable
        srcverts, srcedges = list(), list()
        for i, bb in enumerate(bbs):
            for isrc, bbsrc in enumerate(source.bbs):
                # fragile code... detecting this way can be wrong
                if directions[i] != srcdirn[isrc]:
                    continue
                if [b.filehash for b in bb] == [b.filehash for b in bbsrc]:
                    # super hacky fix, really need to be passed info on what's what
                    if srcverts and srcverts[-1] + 1 != isrc:
                        continue
                    verts[i] = source.verts[isrc]
                    srcverts.append(isrc)

        for i, bb in enumerate(zip(bbs, bbs[1:])):
            bb0, bb1 = bb
            for isrc, bbsrc in enumerate(zip(source.bbs, source.bbs[1:])):
                bbsrc0, bbsrc1 = bbsrc
                if directions[i] != srcdirn[isrc]:
                    continue
                if directions[i + 1] != srcdirn[isrc + 1]:
                    continue
                he = [b.filehash for b in bb0] == [b.filehash for b in bbsrc0]
                he &= [b.filehash for b in bb1] == [b.filehash for b in bbsrc1]
                if not he:
                    continue
                edges[i] = source.edges[isrc]
                srcedges.append(isrc)

    if not make_edges:
        edges = []

    tvertex = time()
    exe = InProcessExecutor()

    if parallel:
        exe = cf.ThreadPoolExecutor(max_workers=parallel)
    with exe as pool:
        if only_seg is not None:
            save = bbs, directions
            bbs = [bbs[only_seg]]
            directions = [directions[only_seg]]
            verts = [verts[only_seg]]
        futures = list()
        for i, bb in enumerate(bbs):
            dirn = directions[i]
            if verts[i] is None:
                futures.append(
                    pool.submit(Vertex, bb, dirn, min_seg_len=min_seg_len))
        verts_new = [f.result() for f in futures]
        isnone = [i for i in range(len(verts)) if verts[i] is None]
        for inew, inone in enumerate(isnone):
            verts[inone] = verts_new[inew]
            if source:
                print('use new vertex', inone)
        # every missing vertex must have been built exactly once (checked via
        # lengths, not a loop index that later loops would rebind)
        assert len(isnone) == len(verts_new)
        if only_ivertex:
            print("!!!!!!! using one ivertex !!!!!", only_ivertex, len(verts),
                  [v.len for v in verts])
            if len(only_ivertex) != len(verts):
                print(
                    "NOT altering verts, len(only_ivertex)!=len(verts) continuing...",
                    "this is ok if part of a sub-protocol")
            else:
                for i, v in enumerate(verts):
                    if v.len > 1:  # could already have been "trimmed"
                        assert only_ivertex[i] < v.len
                        v.reduce_to_only_one_inplace(only_ivertex[i])
        assert all(v for v in verts)
        if only_seg is not None:
            verts = [None] * only_seg + verts + [None] * (len(queries) -
                                                          only_seg - 1)
            bbs, directions = save
    tvertex = time() - tvertex

    # defined even without edges so the `timing` tuple below never fails
    tedge = 0.0
    if make_edges:
        tedge = time()
        for i, e in enumerate(edges):
            if e is not None:
                continue
            edges[i], edge_analysis = Edge(
                verts[i],
                bbs[i],
                verts[i + 1],
                bbs[i + 1],
                splicedb=spdb,
                verbosity=verbosity,
                precache_splices=precache_splices,
                **kw,
            )
            if all(x[6] for x in edge_analysis):
                continue
            _report_edge_splice_failures(edges[i], bbs[i], bbs[i + 1],
                                         edge_analysis)
            assert edges[i].total_allowed_splices() > 0, "invalid splice"
        tedge = time() - tedge
        if print_edge_summary:
            _print_edge_summary(edges)
        # the splice db is optional (Edge accepts splicedb=None)
        if spdb is not None:
            spdb.sync_to_disk()

    toret = SearchSpaceDag(criteria.bbspec, bbs, verts, edges)
    if timing:
        toret = toret, tdb, tvertex, tedge
    return toret
Exemplo n.º 6
0
   nout = sum(len(a) for a in outblk_res.values())
   nent = sum(len(a) for a in inblk_res.values())
   valid_splices = [list() for i in range(nout)]

   swapped = False
   if u.dirn[1] == 0:  # swap so N-to-C!
      swapped = True
      u, ublks, v, vblks = v, vblks, u, ublks
      outblk_res, inblk_res = inblk_res, outblk_res
      outblk, inblk = inblk, outblk

   pairs_with_no_valid_splices = 0
   bblock_pair_analysis = list()
   tcache = 0

   exe = InProcessExecutor()
   if parallel:
      exe = cf.ProcessPoolExecutor(max_workers=parallel)
   # exe = cf.ThreadPoolExecutor(max_workers=parallel) if parallel else InProcessExecutor()
   with exe as pool:
      futures = list()
      ofst0 = 0
      for iblk0, ires0 in outblk_res.items():
         blk0 = ublks[iblk0]
         key0 = blk0.filehash
         t = time()
         cache = splicedb.partial(params, key0) if splicedb else None
         tcache += time() - t
         ofst1 = 0
         for iblk1, ires1 in inblk_res.items():
            blk1 = vblks[iblk1]