def compute_splices(bbdb, bbpairs, verbosity, parallel, pbar, pbar_interval=10.0, **kw):
    bbpairs_shuf = bbpairs.copy()
    shuffle(bbpairs_shuf)
    exe = InProcessExecutor()
    if parallel:
        exe = cf.ProcessPoolExecutor(max_workers=parallel)
    with exe as pool:
        futures = list()
        for bbpair in bbpairs_shuf:
            bbw0 = BBlockWrap(bbdb.bblock(bbpair[0]))
            bbw1 = BBlockWrap(bbdb.bblock(bbpair[1]))
            f = pool.submit(_valid_splice_pairs, bbw0, bbw1, **kw)
            f.bbpair = bbpair
            futures.append(f)
        print("batch compute_splices, npairs:", len(futures))
        fiter = cf.as_completed(futures)
        if pbar:
            fiter = tqdm(fiter, "precache splices",
                         mininterval=pbar_interval, total=len(futures))
        res = {f.bbpair: f.result() for f in fiter}
        return {bbpair: res[bbpair] for bbpair in bbpairs}
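# Illustrative sketch, not part of the original module: compute_splices relies on
# attaching the (file, file) pair to each Future (f.bbpair = bbpair) so that results
# arriving out of order from as_completed() can be re-keyed by pair. The helper
# below shows that pattern in isolation; the worker function and inputs here are
# made up for the demo and do not exist in the library.
def _example_future_metadata_pattern():
    import concurrent.futures as _cf

    def _square(x):  # stand-in for a real worker such as _valid_splice_pairs
        return x * x

    pairs = [("a", 1), ("b", 2), ("c", 3)]
    with _cf.ThreadPoolExecutor(max_workers=2) as pool:
        futures = []
        for key, val in pairs:
            f = pool.submit(_square, val)
            f.key = key  # stash identifying info on the Future itself
            futures.append(f)
        # results come back in completion order; the stashed key restores the pairing
        return {f.key: f.result() for f in _cf.as_completed(futures)}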
def get_allowed_splices(
        u,
        ublks,
        v,
        vblks,
        splicedb=None,
        splice_max_rms=0.7,
        splice_ncontact_cut=30,
        splice_clash_d2=4.0**2,  # ca only
        splice_contact_d2=8.0**2,
        splice_rms_range=6,
        splice_clash_contact_range=60,
        splice_clash_contact_by_helix=True,
        splice_ncontact_no_helix_cut=0,
        splice_nhelix_contacted_cut=0,
        splice_max_chain_length=999999,
        skip_on_fail=True,
        parallel=False,
        verbosity=1,
        cache_sync=0.001,
        precache_splices=False,
        pbar=False,
        pbar_interval=10.0,
        **kw):
    assert (u.dirn[1] + v.dirn[0]) == 1, 'get_allowed_splices dirn mismatch'

    # note: this is duplicated in edge_batch.py and they need to be the same
    params = (splice_max_rms, splice_ncontact_cut, splice_clash_d2,
              splice_contact_d2, splice_rms_range, splice_clash_contact_range,
              splice_clash_contact_by_helix, splice_ncontact_no_helix_cut,
              splice_nhelix_contacted_cut, splice_max_chain_length)

    outidx = _get_outidx(u.inout[:, 1])
    outblk = u.ibblock[outidx]
    outres = u.ires[outidx, 1]

    inblk = v.ibblock[v.inbreaks[:-1]]
    inres = v.ires[v.inbreaks[:-1], 0]
    inblk_breaks = contig_idx_breaks(inblk)

    outblk_res = defaultdict(list)
    for iblk, ires in zip(outblk, outres):
        outblk_res[iblk].append(ires)
    for iblk in outblk_res.keys():
        outblk_res[iblk] = np.array(outblk_res[iblk], 'i4')

    inblk_res = defaultdict(list)
    for iblk, ires in zip(inblk, inres):
        inblk_res[iblk].append(ires)
    for iblk in inblk_res.keys():
        inblk_res[iblk] = np.array(inblk_res[iblk], 'i4')
        assert np.all(sorted(inblk_res[iblk]) == inblk_res[iblk])

    nout = sum(len(a) for a in outblk_res.values())
    nent = sum(len(a) for a in inblk_res.values())
    valid_splices = [list() for i in range(nout)]

    swapped = False
    if u.dirn[1] == 0:  # swap so N-to-C!
        swapped = True
        u, ublks, v, vblks = v, vblks, u, ublks
        outblk_res, inblk_res = inblk_res, outblk_res
        outblk, inblk = inblk, outblk

    pairs_with_no_valid_splices = 0
    tcache = 0

    exe = InProcessExecutor()
    if parallel:
        exe = cf.ProcessPoolExecutor(max_workers=parallel)
    # exe = cf.ThreadPoolExecutor(max_workers=parallel) if parallel else InProcessExecutor()
    with exe as pool:
        futures = list()
        ofst0 = 0
        for iblk0, ires0 in outblk_res.items():
            blk0 = ublks[iblk0]
            key0 = blk0.filehash
            t = time()
            cache = splicedb.partial(params, key0) if splicedb else None
            tcache += time() - t
            ofst1 = 0
            for iblk1, ires1 in inblk_res.items():
                blk1 = vblks[iblk1]
                key1 = blk1.filehash
                if cache and key1 in cache and cache[key1]:
                    splices = cache[key1]
                    future = NonFuture(splices, dummy=True)
                else:
                    future = pool.submit(
                        _jit_splice_metrics, blk0.chains, blk1.chains,
                        blk0.ncac, blk1.ncac, blk0.stubs, blk1.stubs,
                        blk0.connections, blk1.connections, blk0.ss, blk1.ss,
                        blk0.cb, blk1.cb, splice_clash_d2, splice_contact_d2,
                        splice_rms_range, splice_clash_contact_range,
                        splice_clash_contact_by_helix, splice_max_rms,
                        splice_max_chain_length, skip_on_fail)
                fs = (iblk0, iblk1, ofst0, ofst1, ires0, ires1)
                future.stash = fs
                futures.append(future)
                ofst1 += len(ires1)
            ofst0 += len(ires0)

        if verbosity > 0 and tcache > 1.0:
            print('get_allowed_splices read caches time:', tcache)

        future_iter = cf.as_completed(futures)
        if pbar and not precache_splices:
            future_iter = tqdm(cf.as_completed(futures), 'checking splices',
                               mininterval=pbar_interval, total=len(futures))
        for future in future_iter:
            iblk0, iblk1, ofst0, ofst1, ires0, ires1 = future.stash
            result = future.result()
            if len(result) == 5 and isinstance(result[0], np.ndarray):
                # is newly computed result, not from cache
                rms, nclash, ncontact, ncnh, nhc = result
                ok = ((nclash == 0) * (rms <= splice_max_rms) *
                      (ncontact >= splice_ncontact_cut) *
                      (ncnh >= splice_ncontact_no_helix_cut) *
                      (nhc >= splice_nhelix_contacted_cut))
                result = _splice_respairs(ok, ublks[iblk0], vblks[iblk1])
                if np.sum(ok) == 0:
                    print('N no clash', np.sum(nclash == 0))
                    print('N rms', np.sum(rms <= splice_max_rms))
                    print('N contact', np.sum(ncontact >= splice_ncontact_cut))
                if splicedb:
                    key0 = ublks[iblk0].filehash  # C-term side
                    key1 = vblks[iblk1].filehash  # N-term side
                    splicedb.add(params, key0, key1, result)
                    if np.random.random() < cache_sync:
                        print('sync_to_disk splices data')
                        splicedb.sync_to_disk()
            if swapped:
                result = result[1], result[0]
                ires0, ires1 = ires1, ires0
                ofst0, ofst1 = ofst1, ofst0
            if len(result[0]) == 0:
                pairs_with_no_valid_splices += 1
                continue
            index_of_ires0 = _index_of_map(ires0, np.max(result[0]))
            index_of_ires1 = _index_of_map(ires1, np.max(result[1]))
            irs = index_of_ires0[result[0]]
            jrs = index_of_ires1[result[1]]
            ok = (irs >= 0) * (jrs >= 0)
            irs = irs[ok] + ofst0
            jrs = jrs[ok] + ofst1
            for ir, jr in zip(irs, jrs):
                valid_splices[ir].append(jr)

    if cache_sync > 0 and splicedb:
        splicedb.sync_to_disk()

    if pairs_with_no_valid_splices:
        print('pairs with no valid splices: ', pairs_with_no_valid_splices,
              'of', len(outblk_res) * len(inblk_res))

    return valid_splices, nout, nent
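# Illustrative sketch, assumed shapes, not library code: the splice acceptance test
# in get_allowed_splices multiplies boolean arrays elementwise, which is equivalent
# to a logical AND over every candidate splice. The miniature version below shows
# how a few thresholded criteria combine into one "ok" mask with numpy; the data
# values are invented, only the default cutoffs (0.7 rms, 30 contacts) come from
# the function signature above.
def _example_splice_acceptance_mask():
    import numpy as np

    rms = np.array([0.3, 0.9, 0.5])    # pretend per-candidate rms values
    nclash = np.array([0, 0, 2])       # pretend clash counts
    ncontact = np.array([40, 50, 10])  # pretend contact counts
    ok = (nclash == 0) * (rms <= 0.7) * (ncontact >= 30)
    return np.where(ok)[0]             # indices of candidates passing all cuts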
def grow_linear(ssdag,
                loss_function=null_lossfunc,
                loss_threshold=2.0,
                last_bb_same_as=-1,
                parallel=0,
                monte_carlo=0,
                verbosity=0,
                merge_bblock=None,
                lbl='',
                pbar=False,
                pbar_interval=10.0,
                no_duplicate_bases=True,
                max_linear=1000000,
                **kw):
    verts = ssdag.verts
    edges = ssdag.edges
    if last_bb_same_as is None:
        last_bb_same_as = -1
    assert len(verts) > 1
    assert len(verts) == len(edges) + 1
    assert verts[0].dirn[0] == 2
    assert verts[-1].dirn[1] == 2
    for ivertex in range(len(verts) - 1):
        assert verts[ivertex].dirn[1] + verts[ivertex + 1].dirn[0] == 1

    # if isinstance(loss_function, types.FunctionType):
    #     if not 'NUMBA_DISABLE_JIT' in os.environ:
    #         loss_function = nb.njit(nogil=1, fastmath=1)

    exe = cf.ThreadPoolExecutor(
        max_workers=parallel) if parallel else InProcessExecutor()
    # exe = cf.ProcessPoolExecutor(max_workers=parallel) if parallel else InProcessExecutor()
    with exe as pool:
        bb_base = tuple([
            np.array([b.basehash if no_duplicate_bases else 0 for b in bb],
                     dtype=np.int64) for bb in ssdag.bbs
        ])
        verts_pickleable = [v._state for v in verts]
        edges_pickleable = [e._state for e in edges]
        kwargs = dict(
            bb_base=bb_base,
            verts_pickleable=verts_pickleable,
            edges_pickleable=edges_pickleable,
            loss_function=loss_function,
            loss_threshold=loss_threshold,
            last_bb_same_as=last_bb_same_as,
            nresults=0,
            isplice=0,
            splice_position=np.eye(4, dtype=vertex_xform_dtype),
            max_linear=max_linear,
        )
        futures = list()
        if monte_carlo:
            kwargs['fn'] = _grow_linear_mc_start
            kwargs['seconds'] = monte_carlo
            kwargs['ivertex_range'] = (0, verts[0].len)
            kwargs['merge_bblock'] = merge_bblock
            kwargs['lbl'] = lbl
            kwargs['verbosity'] = verbosity
            kwargs['pbar'] = pbar
            kwargs['pbar_interval'] = pbar_interval
            njob = cpu_count() if parallel else 1
            for ivert in range(njob):
                kwargs['threadno'] = ivert
                futures.append(pool.submit(**kwargs))
        else:
            kwargs['fn'] = _grow_linear_start
            nbatch = max(1, int(verts[0].len / 64 / cpu_count()))
            for ivert in range(0, verts[0].len, nbatch):
                ivert_end = min(verts[0].len, ivert + nbatch)
                kwargs['ivertex_range'] = ivert, ivert_end
                futures.append(pool.submit(**kwargs))

        results = list()
        if monte_carlo:
            for f in cf.as_completed(futures):
                results.append(f.result())
        else:
            desc = 'linear search ' + str(lbl)
            if merge_bblock is None:
                merge_bblock = 0
            fiter = cf.as_completed(futures)
            if pbar:
                fiter = tqdm(fiter,
                             desc=desc,
                             position=merge_bblock + 1,
                             mininterval=pbar_interval,
                             total=len(futures))
            for f in fiter:
                results.append(f.result())

    tot_stats = zero_search_stats()
    for i in range(len(tot_stats)):
        tot_stats[i][0] += sum([r.stats[i][0] for r in results])
    result = ResultJIT(pos=np.concatenate([r.pos for r in results]),
                       idx=np.concatenate([r.idx for r in results]),
                       err=np.concatenate([r.err for r in results]),
                       stats=tot_stats)
    result = remove_duplicate_results(result)
    order = np.argsort(result.err)
    return ResultJIT(pos=result.pos[order],
                     idx=result.idx[order],
                     err=result.err[order],
                     stats=result.stats)
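# Illustrative sketch with toy data, not library code: grow_linear merges per-batch
# results by concatenating their arrays and then reorders everything by error with a
# single argsort, so the first entry of the returned result is always the lowest-error
# hit. The toy below mirrors just that merge-then-sort step; the array contents are
# invented for the demo.
def _example_merge_and_sort_results():
    import numpy as np

    err_batches = [np.array([2.0, 0.5]), np.array([1.0])]
    idx_batches = [np.array([[0, 1], [2, 3]]), np.array([[4, 5]])]
    err = np.concatenate(err_batches)
    idx = np.concatenate(idx_batches)
    order = np.argsort(err)  # lowest error first
    return err[order], idx[order]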
def simple_search_dag(criteria,
                      db=None,
                      nbblocks=100,
                      min_seg_len=15,
                      parallel=False,
                      verbosity=0,
                      timing=0,
                      modbbs=None,
                      make_edges=True,
                      merge_bblock=None,
                      precache_splices=False,
                      precache_only=False,
                      bbs=None,
                      only_seg=None,
                      source=None,
                      print_edge_summary=False,
                      no_duplicate_bases=False,
                      shuffle_bblocks=False,
                      use_saved_bblocks=False,
                      output_prefix='./worms',
                      **kw):
    bbdb, spdb = db
    queries, directions = zip(*criteria.bbspec)
    tdb = time()
    if bbs is None:
        bbs = list()
        savename = output_prefix + '_bblocks.pickle'
        if use_saved_bblocks and os.path.exists(savename):
            with open(savename, 'rb') as inp:
                bbnames_list = _pickle.load(inp)
            for bbnames in bbnames_list:
                bbs.append([bbdb.bblock(n) for n in bbnames])
        else:
            for iquery, query in enumerate(queries):
                msegs = [
                    i + len(queries) if i < 0 else i
                    for i in criteria.which_mergeseg()
                ]
                if iquery in msegs[1:]:
                    print('seg', iquery, 'repeating bblocks from', msegs[0])
                    bbs.append(bbs[msegs[0]])
                    continue
                bbs0 = bbdb.query(
                    query,
                    max_bblocks=nbblocks,
                    shuffle_bblocks=shuffle_bblocks,
                    parallel=parallel,
                )
                bbs.append(bbs0)

        bases = [
            Counter(bytes(b.base).decode('utf-8') for b in bbs0)
            for bbs0 in bbs
        ]
        assert len(bbs) == len(queries)
        for i, v in enumerate(bbs):
            assert len(v) > 0, 'no bblocks for query: "' + queries[i] + '"'
        print('bblock queries:', str(queries))
        print('bblock numbers:', [len(b) for b in bbs])
        print('bblocks id:', [id(b) for b in bbs])
        print('bblock0 id ', [id(b[0]) for b in bbs])
        print('base_counts:')
        for query, basecount in zip(queries, bases):
            counts = ' '.join(f'{k}: {c}' for k, c in basecount.items())
            print(f' {query:10}', counts)

        if criteria.is_cyclic:
            for a, b in zip(bbs[criteria.from_seg], bbs[criteria.to_seg]):
                assert a is b
            bbs[criteria.to_seg] = bbs[criteria.from_seg]

        if use_saved_bblocks:
            bbnames = [[bytes(b.file).decode('utf-8') for b in bb]
                       for bb in bbs]
            with open(savename, 'wb') as out:
                _pickle.dump(bbnames, out)
    else:
        bbs = bbs.copy()

    assert len(bbs) == len(criteria.bbspec)
    if modbbs:
        modbbs(bbs)

    if merge_bblock is not None and merge_bblock >= 0:
        # print('which_mergeseg', criteria.bbspec, criteria.which_mergeseg())
        for i in criteria.which_mergeseg():
            bbs[i] = (bbs[i][merge_bblock], )

    tdb = time() - tdb
    # info(
    #     f'bblock creation time {tdb:7.3f} num bbs: ' +
    #     str([len(x) for x in bbs])
    # )

    if precache_splices:
        bbnames = [[bytes(bb.file) for bb in bbtup] for bbtup in bbs]
        bbpairs = set()
        # for bb1, bb2, dirn1 in zip(bbnames, bbnames[1:], directions):
        for i in range(len(bbnames) - 1):
            bb1 = bbnames[i]
            bb2 = bbnames[i + 1]
            dirn1 = directions[i]
            rev = dirn1[1] == 'N'
            if bbs[i] is bbs[i + 1]:
                bbpairs.update((a, a) for a in bb1)
            else:
                bbpairs.update(
                    (b, a) if rev else (a, b) for a in bb1 for b in bb2)
        precompute_splicedb(db,
                            bbpairs,
                            verbosity=verbosity,
                            parallel=parallel,
                            **kw)
    if precache_only:
        return bbs

    verts = [None] * len(queries)
    edges = [None] * len(queries[1:])
    if source:
        srcdirn = [''.join('NC_'[d] for d in source.verts[i].dirn)
                   for i in range(len(source.verts))]  # yapf: disable
        srcverts, srcedges = list(), list()
        for i, bb in enumerate(bbs):
            for isrc, bbsrc in enumerate(source.bbs):
                if directions[i] != srcdirn[isrc]:
                    continue
                if [b.filehash for b in bb] == [b.filehash for b in bbsrc]:
                    verts[i] = source.verts[isrc]
                    srcverts.append(isrc)
        for i, bb in enumerate(zip(bbs, bbs[1:])):
            bb0, bb1 = bb
            for isrc, bbsrc in enumerate(zip(source.bbs, source.bbs[1:])):
                bbsrc0, bbsrc1 = bbsrc
                if directions[i] != srcdirn[isrc]:
                    continue
                if directions[i + 1] != srcdirn[isrc + 1]:
                    continue
                he = [b.filehash for b in bb0] == [b.filehash for b in bbsrc0]
                he &= [b.filehash for b in bb1] == [b.filehash for b in bbsrc1]
                if not he:
                    continue
                edges[i] = source.edges[isrc]
                srcedges.append(isrc)

    if not make_edges:
        edges = []

    tvertex = time()
    exe = InProcessExecutor()
    if parallel:
        exe = cf.ThreadPoolExecutor(max_workers=parallel)
    with exe as pool:
        if only_seg is not None:
            save = bbs, directions
            bbs = [bbs[only_seg]]
            directions = [directions[only_seg]]
            verts = [verts[only_seg]]
        futures = list()
        for i, bb in enumerate(bbs):
            dirn = directions[i]
            if verts[i] is None:
                futures.append(
                    pool.submit(Vertex, bb, dirn, min_seg_len=min_seg_len))
        verts_new = [f.result() for f in futures]
        isnone = [i for i in range(len(verts)) if verts[i] is None]
        for i, inone in enumerate(isnone):
            verts[inone] = verts_new[i]
        # print(i, len(verts_new), len(verts))
        if isnone:
            assert i + 1 == len(verts_new)
        assert all(v for v in verts)

        if only_seg is not None:
            verts = ([None] * only_seg + verts +
                     [None] * (len(queries) - only_seg - 1))
            bbs, directions = save
    tvertex = time() - tvertex
    # info(
    #     f'vertex creation time {tvertex:7.3f} num verts ' +
    #     str([v.len if v else 0 for v in verts])
    # )

    if make_edges:
        tedge = time()
        for i, e in enumerate(edges):
            if e is None:
                edges[i] = Edge(verts[i],
                                bbs[i],
                                verts[i + 1],
                                bbs[i + 1],
                                splicedb=spdb,
                                verbosity=verbosity,
                                precache_splices=precache_splices,
                                **kw)
        tedge = time() - tedge
        if print_edge_summary:
            _print_edge_summary(edges)
        # info(
        #     f'edge creation time {tedge:7.3f} num splices ' +
        #     str([e.total_allowed_splices()
        #          for e in edges]) + ' num exits ' + str([e.len for e in edges])
        # )

    spdb.sync_to_disk()
    toret = SearchSpaceDag(criteria.bbspec, bbs, verts, edges)
    if timing:
        toret = toret, tdb, tvertex, tedge
    return toret
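# Illustrative sketch with toy names, not library code: when precaching splices,
# simple_search_dag builds each bblock-file pair and reverses it whenever the first
# segment's exit direction is 'N' (rev = dirn1[1] == 'N'), presumably so cached
# pairs always use one consistent orientation. The toy below reproduces only that
# ordering decision; the file names and direction string are invented.
def _example_bbpair_direction():
    bb1 = [b"fileA.pdb"]
    bb2 = [b"fileB.pdb"]
    dirn1 = "_N"  # pretend this segment exits toward N
    rev = dirn1[1] == "N"
    bbpairs = set(
        (b, a) if rev else (a, b) for a in bb1 for b in bb2)
    return bbpairs  # {(b"fileB.pdb", b"fileA.pdb")} because rev is True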
def simple_search_dag(
        criteria,
        db=None,
        nbblocks=[64],
        min_seg_len=15,
        parallel=False,
        verbosity=0,
        timing=0,
        modbbs=None,
        make_edges=True,
        merge_bblock=None,
        merge_segment=None,
        precache_splices=False,
        precache_only=False,
        bbs=None,
        bblock_ranges=[],
        only_seg=None,
        source=None,
        print_edge_summary=False,
        no_duplicate_bases=False,
        shuffle_bblocks=False,
        use_saved_bblocks=False,
        output_prefix="./worms",
        only_ivertex=[],
        **kw,
):
    bbdb, spdb = db
    queries, directions = zip(*criteria.bbspec)
    tdb = time()
    if bbs is None:
        bbs = list()
        savename = output_prefix + "_bblocks.pickle"
        if use_saved_bblocks and os.path.exists(savename):
            with open(savename, "rb") as inp:
                bbnames_list = _pickle.load(inp)
            # for i, l in enumerate(bbnames_list):
            #     if len(l) >= nbblocks[i]:
            #         assert 0, f"too many bblocks in {savename}"
            for i, bbnames in enumerate(bbnames_list):
                bbs.append([bbdb.bblock(n) for n in bbnames[:nbblocks[i]]])
        else:
            for iquery, query in enumerate(queries):
                if hasattr(criteria, "cloned_segments"):
                    msegs = [
                        i + len(queries) if i < 0 else i
                        for i in criteria.cloned_segments()
                    ]
                    if iquery in msegs[1:]:
                        print("seg", iquery, "repeating bblocks from",
                              msegs[0])
                        bbs.append(bbs[msegs[0]])
                        continue
                bbs0 = bbdb.query(
                    query,
                    max_bblocks=nbblocks[iquery],
                    shuffle_bblocks=shuffle_bblocks,
                    parallel=parallel,
                )
                bbs.append(bbs0)

        if bblock_ranges:
            bbs_sliced = list()
            assert len(bblock_ranges) == 2 * len(bbs)
            for ibb, bb in enumerate(bbs):
                lb, ub = bblock_ranges[2 * ibb:2 * ibb + 2]
                bbs_sliced.append(bb[lb:ub])
            bbs = bbs_sliced

        for ibb, bb in enumerate(bbs):
            print("bblocks", ibb)
            for b in bb:
                print(" ", bytes(b.file).decode("utf-8"))

        bases = [
            Counter(bytes(b.base).decode("utf-8") for b in bbs0)
            for bbs0 in bbs
        ]
        assert len(bbs) == len(queries)
        for i, v in enumerate(bbs):
            assert len(v) > 0, 'no bblocks for query: "' + queries[i] + '"'
        print("bblock queries:", str(queries))
        print("bblock numbers:", [len(b) for b in bbs])
        print("bblocks id:", [id(b) for b in bbs])
        print("bblock0 id ", [id(b[0]) for b in bbs])
        print("base_counts:")
        for query, basecount in zip(queries, bases):
            counts = " ".join(f"{k}: {c}" for k, c in basecount.items())
            print(f" {query:10}", counts)

        if criteria.is_cyclic:
            # for a, b in zip(bbs[criteria.from_seg], bbs[criteria.to_seg]):
            #     assert a is b
            bbs[criteria.to_seg] = bbs[criteria.from_seg]

        if use_saved_bblocks and not os.path.exists(savename):
            bbnames = [[bytes(b.file).decode("utf-8") for b in bb]
                       for bb in bbs]
            with open(savename, "wb") as out:
                _pickle.dump(bbnames, out)
    else:
        bbs = bbs.copy()

    assert len(bbs) == len(criteria.bbspec)
    if modbbs:
        modbbs(bbs)

    if merge_bblock is not None and merge_bblock >= 0:
        # print('cloned_segments', criteria.bbspec, criteria.cloned_segments())
        if hasattr(criteria, "cloned_segments") and merge_segment is None:
            for i in criteria.cloned_segments():
                # print(' ', 'merge seg', i, 'merge_bblock', merge_bblock)
                bbs[i] = (bbs[i][merge_bblock], )
        else:
            if merge_segment is None:
                merge_segment = 0
            # print(' ', 'merge_segment not None')
            # print(' ', [len(b) for b in bbs])
            # print(' ', 'merge_segment', merge_segment)
            # print(' ', 'merge_bblock', merge_bblock, len(bbs[merge_segment]))
            bbs[merge_segment] = (bbs[merge_segment][merge_bblock], )

    tdb = time() - tdb
    # info(
    #     f'bblock creation time {tdb:7.3f} num bbs: ' +
    #     str([len(x) for x in bbs])
    # )

    if precache_splices:
        bbnames = [[bytes(bb.file) for bb in bbtup] for bbtup in bbs]
        bbpairs = set()
        # for bb1, bb2, dirn1 in zip(bbnames, bbnames[1:], directions):
        for i in range(len(bbnames) - 1):
            bb1 = bbnames[i]
            bb2 = bbnames[i + 1]
            dirn1 = directions[i]
            rev = dirn1[1] == "N"
            if bbs[i] is bbs[i + 1]:
                bbpairs.update((a, a) for a in bb1)
            else:
                bbpairs.update(
                    (b, a) if rev else (a, b) for a in bb1 for b in bb2)
        precompute_splicedb(db,
                            bbpairs,
                            verbosity=verbosity,
                            parallel=parallel,
                            **kw)
    if precache_only:
        return bbs

    verts = [None] * len(queries)
    edges = [None] * len(queries[1:])
    if source:
        srcdirn = [
            "".join("NC_"[d] for d in source.verts[i].dirn)
            for i in range(len(source.verts))
        ]  # yapf: disable
        srcverts, srcedges = list(), list()
        for i, bb in enumerate(bbs):
            for isrc, bbsrc in enumerate(source.bbs):
                # fragile code... detecting this way can be wrong
                # print(i, isrc, directions[i], srcdirn[isrc])
                if directions[i] != srcdirn[isrc]:
                    continue
                if [b.filehash for b in bb] == [b.filehash for b in bbsrc]:
                    # super hacky fix, really need to be passed info on what's what
                    if srcverts and srcverts[-1] + 1 != isrc:
                        continue
                    verts[i] = source.verts[isrc]
                    srcverts.append(isrc)
        for i, bb in enumerate(zip(bbs, bbs[1:])):
            bb0, bb1 = bb
            for isrc, bbsrc in enumerate(zip(source.bbs, source.bbs[1:])):
                bbsrc0, bbsrc1 = bbsrc
                if directions[i] != srcdirn[isrc]:
                    continue
                if directions[i + 1] != srcdirn[isrc + 1]:
                    continue
                he = [b.filehash for b in bb0] == [b.filehash for b in bbsrc0]
                he &= [b.filehash for b in bb1] == [b.filehash for b in bbsrc1]
                if not he:
                    continue
                edges[i] = source.edges[isrc]
                srcedges.append(isrc)

    if not make_edges:
        edges = []

    tvertex = time()
    exe = InProcessExecutor()
    if parallel:
        exe = cf.ThreadPoolExecutor(max_workers=parallel)
    with exe as pool:
        if only_seg is not None:
            save = bbs, directions
            bbs = [bbs[only_seg]]
            directions = [directions[only_seg]]
            verts = [verts[only_seg]]
        futures = list()
        for i, bb in enumerate(bbs):
            dirn = directions[i]
            if verts[i] is None:
                futures.append(
                    pool.submit(Vertex, bb, dirn, min_seg_len=min_seg_len))
        verts_new = [f.result() for f in futures]
        isnone = [i for i in range(len(verts)) if verts[i] is None]
        for i, inone in enumerate(isnone):
            verts[inone] = verts_new[i]
            if source:
                print('use new vertex', inone)
        if only_ivertex:
            # raise NotImplementedError
            print("!!!!!!! using one ivertex !!!!!", only_ivertex, len(verts),
                  [v.len for v in verts])
            if len(only_ivertex) != len(verts):
                print(
                    "NOT altering verts, len(only_ivertex)!=len(verts) continuing...",
                    "this is ok if part of a sub-protocol")
            else:
                for i, v in enumerate(verts):
                    if v.len > 1:  # could already have been "trimmed"
                        assert only_ivertex[i] < v.len
                        v.reduce_to_only_one_inplace(only_ivertex[i])
                        # print('x2exit', v.x2exit.shape)
                        # print('x2orig', v.x2orig.shape)
                        # print('ires', v.ires.shape)
                        # print('isite', v.isite.shape)
                        # print('ichain', v.ichain.shape)
                        # print('ibblock', v.ibblock.shape)
                        # print('inout', v.inout.shape, v.inout[10:])
                        # print('inbreaks', v.inbreaks.shape, v.inbreaks[10:])
                        # print('dirn', v.dirn.shape)
                        # # assert 0
        # print(i, len(verts_new), len(verts))
        if isnone:
            assert i + 1 == len(verts_new)
        assert all(v for v in verts)

        if only_seg is not None:
            verts = ([None] * only_seg + verts +
                     [None] * (len(queries) - only_seg - 1))
            bbs, directions = save
    tvertex = time() - tvertex
    # info(
    #     f'vertex creation time {tvertex:7.3f} num verts ' +
    #     str([v.len if v else 0 for v in verts])
    # )

    if make_edges:
        tedge = time()
        for i, e in enumerate(edges):
            if e is not None:
                continue
            edges[i], edge_analysis = Edge(
                verts[i],
                bbs[i],
                verts[i + 1],
                bbs[i + 1],
                splicedb=spdb,
                verbosity=verbosity,
                precache_splices=precache_splices,
                **kw,
            )
            allok = all(x[6] for x in edge_analysis)
            if allok:
                continue
            print("=" * 80)
            print("info for edges with no valid splices",
                  edges[i].total_allowed_splices())
            for tup in edge_analysis:
                iblk0, iblk1, ofst0, ofst1, ires0, ires1 = tup[:6]
                ok, f_clash, f_rms, f_ncontact, f_ncnh, f_nhc = tup[6:12]
                m_rms, m_ncontact, m_ncnh, m_nhc = tup[12:]
                if ok:
                    continue
                assert len(bbs[i + 0]) > iblk0
                assert len(bbs[i + 1]) > iblk1
                print("=" * 80)
                print("edge Bblock A", bytes(bbs[i][iblk0].file))
                print("edge Bblock B", bytes(bbs[i + 1][iblk1].file))
                print(
                    f"bb {iblk0:3} {iblk1:3}",
                    f"ofst {ofst0:4} {ofst1:4}",
                    f"resi {ires0.shape} {ires1.shape}",
                )
                print(
                    f"clash_ok {int(f_clash*100):3}%",
                    f"rms_ok {int(f_rms*100):3}%",
                    f"ncontact_ok {int(f_ncontact*100):3}%",
                    f"ncnh_ok {int(f_ncnh*100):3}%",
                    f"nhc_ok {int(f_nhc*100):3}%",
                )
                print(
                    f"min_rms {m_rms:7.3f}",
                    f"max_ncontact {m_ncontact:7.3f}",
                    f"max_ncnh {m_ncnh:7.3f}",
                    f"max_nhc {m_nhc:7.3f}",
                )
            print("=" * 80)
            fok = np.stack([x[7:12] for x in edge_analysis]).mean(axis=0)
            rmsmin = np.array([x[12] for x in edge_analysis]).min()
            fmx = np.stack([x[13:] for x in edge_analysis]).max(axis=0)
            print(f"{' SPLICE FAIL SUMMARY ':=^80}")
            print(f"splice clash ok {int(fok[0]*100):3}%")
            print(f"splice rms ok {int(fok[1]*100):3}%")
            print(f"splice ncontacts ok {int(fok[2]*100):3}%")
            print(f"splice ncontacts_no_helix ok {int(fok[3]*100):3}%")
            print(f"splice nhelixcontacted ok {int(fok[4]*100):3}%")
            print(f"min rms of any failing {rmsmin}")
            print(
                f"max ncontact of any failing {fmx[0]} (maybe large for non-5-helix splice)"
            )
            print(
                f"max ncontact_no_helix {fmx[1]} (will be 999 for non-5-helix splice)"
            )
            print(
                f"max nhelix_contacted {fmx[2]} (will be 999 for non-5-helix splice)"
            )
            print("=" * 80)
            assert edges[i].total_allowed_splices() > 0, "invalid splice"
        tedge = time() - tedge
        if print_edge_summary:
            _print_edge_summary(edges)
        # info(
        #     f'edge creation time {tedge:7.3f} num splices ' +
        #     str([e.total_allowed_splices()
        #          for e in edges]) + ' num exits ' + str([e.len for e in edges])
        # )

    spdb.sync_to_disk()
    toret = SearchSpaceDag(criteria.bbspec, bbs, verts, edges)
    if timing:
        toret = toret, tdb, tvertex, tedge
    return toret
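# Illustrative sketch with made-up tuples, not library code: the splice-fail summary
# above reduces the per-bblock-pair analysis tuples by averaging the "fraction
# passing" columns (indices 7:12) and taking a min over the rms column (12) and a
# max over the remaining contact columns (13:). The toy below mirrors that reduction
# on two fake analysis rows laid out the same way as the tuples unpacked above.
def _example_edge_fail_summary():
    import numpy as np

    # columns: 0-5 indices/offsets/residue arrays, 6 ok flag,
    # 7-11 fraction passing each filter, 12 min rms, 13-15 max contact metrics
    fake_analysis = [
        (0, 0, 0, 0, None, None, False, 0.9, 0.1, 0.5, 1.0, 1.0, 1.8, 12, 999, 999),
        (0, 1, 0, 4, None, None, False, 0.8, 0.0, 0.4, 1.0, 1.0, 2.1, 15, 999, 999),
    ]
    fok = np.stack([x[7:12] for x in fake_analysis]).mean(axis=0)
    rmsmin = np.array([x[12] for x in fake_analysis]).min()
    fmx = np.stack([x[13:] for x in fake_analysis]).max(axis=0)
    return fok, rmsmin, fmx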