def _grow_linear_start(bb_base, verts_pickleable, edges_pickleable, **kwargs):
    """Worker entry point for one batch of the exhaustive linear search.

    Rebuilds _Vertex/_Edge objects from their pickleable state, allocates
    scratch result buffers, runs _grow_linear_recurse, and returns a
    ResultJIT trimmed to the number of results actually produced.
    """
    rebuilt_verts = tuple(_Vertex(*state) for state in verts_pickleable)
    rebuilt_edges = tuple(_Edge(*state) for state in edges_pickleable)
    nverts = len(rebuilt_verts)
    # pre-allocated scratch buffers: 1024 result slots
    scratch = ResultJIT(
        pos=np.empty(shape=(1024, nverts, 4, 4), dtype=np.float32),
        idx=np.empty(shape=(1024, nverts), dtype=np.int32),
        err=np.empty(shape=(1024, ), dtype=np.float32),
        stats=zero_search_stats(),
    )
    base_hashes = np.zeros(nverts, dtype=np.int64)
    count, scratch = _grow_linear_recurse(
        result=scratch,
        bb_base=bb_base,
        verts=rebuilt_verts,
        edges=rebuilt_edges,
        bases=base_hashes,
        **kwargs,
    )
    # keep only the slots the recursion actually filled
    return ResultJIT(scratch.pos[:count], scratch.idx[:count],
                     scratch.err[:count], scratch.stats)
def grow_linear(ssdag,
                loss_function=null_lossfunc,
                loss_threshold=2.0,
                last_bb_same_as=-1,
                parallel=0,
                monte_carlo=0,
                verbosity=0,
                merge_bblock=None,
                lbl='',
                pbar=False,
                pbar_interval=10.0,
                no_duplicate_bases=True,
                max_linear=1000000,
                **kw):
    """Search for valid linear assemblies through the splice graph *ssdag*.

    When *monte_carlo* is nonzero it is used as a per-worker time budget
    (seconds) for stochastic sampling via _grow_linear_mc_start; otherwise
    the first vertex's index range is split into batches and searched
    exhaustively via _grow_linear_start. Work runs in a ThreadPoolExecutor
    when *parallel* is set, else in-process.

    Returns a ResultJIT with duplicates removed and entries sorted by
    ascending err.
    """
    verts = ssdag.verts
    edges = ssdag.edges
    if last_bb_same_as is None:
        last_bb_same_as = -1
    # a linear path needs >= 2 vertices and exactly one edge per adjacent pair
    assert len(verts) > 1
    assert len(verts) == len(edges) + 1
    # NOTE(review): dirn value 2 appears to mark an unspliced terminus at
    # either end of the path — confirm against _Vertex
    assert verts[0].dirn[0] == 2
    assert verts[-1].dirn[1] == 2
    # interior neighbors must have complementary directions (sum to 1)
    for ivertex in range(len(verts) - 1):
        assert verts[ivertex].dirn[1] + verts[ivertex + 1].dirn[0] == 1
    # if isinstance(loss_function, types.FunctionType):
    #     if not 'NUMBA_DISABLE_JIT' in os.environ:
    #         loss_function = nb.njit(nogil=1, fastmath=1)

    exe = cf.ThreadPoolExecutor(
        max_workers=parallel) if parallel else InProcessExecutor()
    # exe = cf.ProcessPoolExecutor(max_workers=parallel) if parallel else InProcessExecutor()
    with exe as pool:
        # per-bblock base hashes; zeroed when duplicate bases are allowed,
        # so the dedup test downstream compares equal for everything
        bb_base = tuple([
            np.array([b.basehash if no_duplicate_bases else 0 for b in bb],
                     dtype=np.int64) for bb in ssdag.bbs
        ])
        # workers rebuild _Vertex/_Edge from this pickleable state
        verts_pickleable = [v._state for v in verts]
        edges_pickleable = [e._state for e in edges]
        kwargs = dict(
            bb_base=bb_base,
            verts_pickleable=verts_pickleable,
            edges_pickleable=edges_pickleable,
            loss_function=loss_function,
            loss_threshold=loss_threshold,
            last_bb_same_as=last_bb_same_as,
            nresults=0,
            isplice=0,
            splice_position=np.eye(4, dtype=vertex_xform_dtype),
            max_linear=max_linear,
        )
        futures = list()
        if monte_carlo:
            # one time-limited sampling job per worker; kwargs dict is
            # mutated then re-submitted each iteration
            kwargs['fn'] = _grow_linear_mc_start
            kwargs['seconds'] = monte_carlo
            kwargs['ivertex_range'] = (0, verts[0].len)
            kwargs['merge_bblock'] = merge_bblock
            kwargs['lbl'] = lbl
            kwargs['verbosity'] = verbosity
            kwargs['pbar'] = pbar
            kwargs['pbar_interval'] = pbar_interval
            njob = cpu_count() if parallel else 1
            for ivert in range(njob):
                kwargs['threadno'] = ivert
                futures.append(pool.submit(**kwargs))
        else:
            # exhaustive search, batched over the first vertex's index range
            kwargs['fn'] = _grow_linear_start
            nbatch = max(1, int(verts[0].len / 64 / cpu_count()))
            for ivert in range(0, verts[0].len, nbatch):
                ivert_end = min(verts[0].len, ivert + nbatch)
                kwargs['ivertex_range'] = ivert, ivert_end
                futures.append(pool.submit(**kwargs))
        results = list()
        if monte_carlo:
            for f in cf.as_completed(futures):
                results.append(f.result())
        else:
            desc = 'linear search ' + str(lbl)
            if merge_bblock is None:
                merge_bblock = 0
            fiter = cf.as_completed(futures)
            if pbar:
                fiter = tqdm(fiter,
                             desc=desc,
                             position=merge_bblock + 1,
                             mininterval=pbar_interval,
                             total=len(futures))
            for f in fiter:
                results.append(f.result())
    # sum the per-counter search statistics across all workers
    tot_stats = zero_search_stats()
    for i in range(len(tot_stats)):
        tot_stats[i][0] += sum([r.stats[i][0] for r in results])
    result = ResultJIT(pos=np.concatenate([r.pos for r in results]),
                       idx=np.concatenate([r.idx for r in results]),
                       err=np.concatenate([r.err for r in results]),
                       stats=tot_stats)
    result = remove_duplicate_results(result)
    # best (lowest err) first
    order = np.argsort(result.err)
    return ResultJIT(pos=result.pos[order],
                     idx=result.idx[order],
                     err=result.err[order],
                     stats=result.stats)
def _grow_linear_mc_start(seconds, verts_pickleable, edges_pickleable,
                          threadno, pbar, lbl, verbosity, merge_bblock,
                          pbar_interval, **kwargs):
    """Worker entry point for time-limited monte-carlo linear search.

    Repeatedly runs _grow_linear_mc batches until *seconds* have elapsed or
    kwargs['max_linear'] results accumulate, deduplicating every 10th batch.
    Thread 0 optionally drives a tqdm progress bar measured in seconds.

    Fixes vs. original: the progress bar is tracked with an explicit
    ``pbar_inst = None`` instead of the fragile ``'pbar_inst' in vars()``
    check, and the loop counter no longer shadows the builtin ``iter``.
    """
    tstart = time()
    verts = tuple([_Vertex(*vp) for vp in verts_pickleable])
    edges = tuple([_Edge(*ep) for ep in edges_pickleable])
    # scratch result buffers: 1024 slots
    pos = np.empty(shape=(1024, len(verts), 4, 4), dtype=np.float32)
    idx = np.empty(shape=(1024, len(verts)), dtype=np.int32)
    err = np.empty(shape=(1024, ), dtype=np.float32)
    stats = zero_search_stats()
    result = ResultJIT(pos=pos, idx=idx, err=err, stats=stats)
    bases = np.zeros(len(verts), dtype=np.int64)
    del kwargs['nresults']  # this function tracks its own running count
    pbar_inst = None  # only thread 0 (with pbar enabled) owns a progress bar
    last = tstart
    if threadno == 0 and pbar:
        desc = 'linear search ' + str(lbl)
        if merge_bblock is None:
            merge_bblock = 0
        pbar_inst = tqdm(desc=desc,
                         position=merge_bblock + 1,
                         total=seconds,
                         mininterval=pbar_interval)
    # batch size shrinks as the path length (number of edges) grows
    nbatch = [1000, 330, 100, 33, 10, 3] + [1] * 99
    nbatch = nbatch[len(edges)] * 10
    nresults = 0
    niter = 0
    ndups = 0
    while time() < tstart + seconds:
        if pbar_inst is not None:
            # progress bar advances by elapsed wall-clock seconds
            pbar_inst.update(time() - last)
            last = time()
        nresults, result = _grow_linear_mc(nbatch,
                                           result,
                                           verts,
                                           edges,
                                           bases=bases,
                                           nresults=nresults,
                                           **kwargs)
        niter += 1
        # remove duplicates every 10th iter
        if niter % 10 == 0:
            nresults_with_dups = nresults
            uniq_result = ResultJIT(idx=result.idx[:nresults],
                                    pos=result.pos[:nresults],
                                    err=result.err[:nresults],
                                    stats=result.stats)
            uniq_result = remove_duplicate_results(uniq_result)
            nresults = len(uniq_result.err)
            # compact the unique entries back into the scratch buffers
            result.idx[:nresults] = uniq_result.idx
            result.pos[:nresults] = uniq_result.pos
            result.err[:nresults] = uniq_result.err
            ndups += nresults_with_dups - nresults
            # print(ndups / nresults)
        if nresults >= kwargs['max_linear']:
            break
    if pbar_inst is not None:
        pbar_inst.close()
    result = ResultJIT(result.pos[:nresults], result.idx[:nresults],
                       result.err[:nresults], result.stats)
    return result
def merge_results_bridge_long(criteria, critC, ssdag, ssdB, ssdC, rsltC, **kw):
    """Merge C-side partial results with B-side partials found via hashing.

    For each result in *rsltC*, bins the from_seg→to_seg transform and looks
    it up in critC's filter hash to recover the matching B-side index tuple,
    then splices the two index paths at the shared bblock and rescores the
    full assembly, keeping those within kw['tolerance'].

    Returns a ResultJIT of merged results, or None if nothing passed.

    Fix vs. original: ``len(imerge) is 0`` identity-compared an int literal
    (unreliable, SyntaxWarning on modern Python); now ``== 0``. Unused debug
    locals (iinC/iinB/ioutC/ioutB/ibbC/ibbB) were removed.
    """
    # look up rsltCs in critC hashtable to get Bs
    # vertex-choice counts per B segment; used to decode packed hash values
    sizesB = np.array([len(v.ibblock) for v in ssdB.verts])
    idx_list = list()
    pos_list = list()
    err_list = list()
    missing = 0
    for iresult in range(len(rsltC.err)):
        idxC = rsltC.idx[iresult]
        posC = rsltC.pos[iresult]
        # relative transform from from_seg to to_seg — the hash key
        xhat = posC[criteria.to_seg] @ np.linalg.inv(posC[criteria.from_seg])
        xhat = xhat.astype(np.float64)
        key = critC.filter_binner(xhat)
        val = critC.filter_hash.get(key)
        assert val < np.prod(sizesB)
        if val == -123_456_789:  # sentinel: transform not in the hash table
            missing += 1
            continue
        idxB = decode_indices(sizesB, val)
        # B's first vertex and C's last vertex must be the same bblock...
        merge_ibblock = ssdB.verts[0].ibblock[idxB[0]]
        merge_ibblock_c = ssdC.verts[-1].ibblock[idxC[-1]]
        if merge_ibblock != merge_ibblock_c:
            continue
        # ...spliced through two distinct sites
        merge_site1 = ssdB.verts[0].isite[idxB[0], 1]
        merge_site2 = ssdC.verts[-1].isite[idxC[-1], 0]
        if merge_site1 == merge_site2:
            continue
        merge_outres = ssdB.verts[0].ires[idxB[0], 1]
        merge_inres = ssdC.verts[-1].ires[idxC[-1], 0]
        # locate the full-graph vertex entry matching bblock + in/out residues
        imergeseg = len(idxC) - 1
        vmerge = ssdag.verts[imergeseg]
        w = ((vmerge.ibblock == merge_ibblock) *
             (vmerge.ires[:, 0] == merge_inres) *
             (vmerge.ires[:, 1] == merge_outres))
        imerge = np.where(w)[0]
        if len(imerge) == 0:
            print(" empty imerge")
            continue
        if len(imerge) > 1:
            print("  imerge", imerge)
        assert len(imerge) == 1
        imerge = imerge[0]
        # spliced index path: C's prefix, the merge vertex, B's suffix
        idx = np.concatenate([idxC[:-1], [imerge], idxB[1:]])
        # compute pos and err by accumulating each vertex's exit transform
        spos = np.eye(4)
        pos = list()
        for i, v in enumerate(ssdag.verts):
            index = idx[i]
            pos.append(spos @ v.x2orig[index])
            spos = spos @ v.x2exit[index]
        pos = np.stack(pos)
        err = criteria.score(pos)
        if err > kw["tolerance"]:
            continue
        idx_list.append(idx)
        pos_list.append(pos)
        err_list.append(err)
    if missing > 0:
        print(
            "merge_results_bridge_long: xform missing from hash_table:",
            missing,
            "of",
            len(rsltC.err),
        )
    assert missing == 0
    if len(pos_list) > 0:
        return ResultJIT(
            np.stack(pos_list),
            np.stack(idx_list),
            np.array(err_list),
            SearchStats(0, 0, 0),
        )
    return None
def prune_clashes(
    ssdag,
    crit,
    rslt,
    max_clash_check=-1,
    ca_clash_dis=4.0,
    parallel=False,
    approx=0,
    verbosity=0,
    merge_bblock=None,
    pbar=False,
    pbar_interval=10.0,
    context_structure=None,
    **kw,
):
    """Filter *rslt* down to entries with no CA-CA clashes.

    Checks up to *max_clash_check* results (all if negative). Results that
    clash with *context_structure* (when it is a ClashGrid) are dropped
    before the per-chain clash check. Returns a ResultJIT containing only
    the surviving entries.

    Fixes vs. original: the status print used ``if merge_bblock`` and so
    printed 'none' for merge_bblock == 0, inconsistent with the explicit
    ``is not None`` checks below; loop-invariant dirns/iress are now built
    once instead of per result.
    """
    # print('todo: clash check should handle symmetry')
    if max_clash_check == 0:
        return rslt
    max_clash_check = min(max_clash_check, len(rslt.idx))
    if max_clash_check < 0:
        max_clash_check = len(rslt.idx)
    if not pbar:
        print(
            f"mbb{f'{merge_bblock:04}' if merge_bblock is not None else 'none'} checking clashes",
            max_clash_check,
            "of",
            len(rslt.err),
        )
    verts = tuple(ssdag.verts)
    # loop-invariant per-vertex data, hoisted out of the result loop
    dirns = tuple([v.dirn for v in verts])
    iress = tuple([v.ires for v in verts])
    # exe = cf.ProcessPoolExecutor if parallel else InProcessExecutor
    exe = InProcessExecutor
    with exe() as pool:
        futures = list()
        for i in range(max_clash_check):
            chains = tuple([
                ssdag.bbs[k][verts[k].ibblock[rslt.idx[i, k]]].chains
                for k in range(len(ssdag.verts))
            ])
            ncacs = tuple([
                ssdag.bbs[k][verts[k].ibblock[rslt.idx[i, k]]].ncac
                for k in range(len(ssdag.verts))
            ])
            if isinstance(context_structure, ClashGrid):
                # drop results clashing with the fixed context structure;
                # skipped results simply never get a future (ok stays False)
                clash = False
                for pos, ncac in zip(rslt.pos[i], ncacs):
                    xyz = pos @ ncac[..., None]
                    if context_structure.clashcheck(xyz.squeeze()):
                        clash = True
                        break
                if clash:
                    continue
            futures.append(
                pool.submit(
                    _check_all_chain_clashes,
                    dirns=dirns,
                    iress=iress,
                    idx=rslt.idx[i],
                    pos=rslt.pos[i],
                    chn=chains,
                    ncacs=ncacs,
                    thresh=ca_clash_dis * ca_clash_dis,
                    approx=approx,
                ))
            futures[-1].index = i
        if pbar:
            desc = "checking clashes "
            if merge_bblock is not None and merge_bblock >= 0:
                desc = f"{desc} mbb{merge_bblock:04d}"
            if merge_bblock is None:
                merge_bblock = 0
            futures = tqdm(
                cf.as_completed(futures),
                desc=desc,
                total=len(futures),
                mininterval=pbar_interval,
                position=merge_bblock + 1,
            )
        # zeros, not empty: context-clash skips must stay False
        ok = np.zeros(max_clash_check, dtype="?")
        for f in futures:
            ok[f.index] = f.result()
    return ResultJIT(
        rslt.pos[:max_clash_check][ok],
        rslt.idx[:max_clash_check][ok],
        rslt.err[:max_clash_check][ok],
        rslt.stats,
    )
def prune_clashes(ssdag,
                  crit,
                  rslt,
                  max_clash_check=-1,
                  ca_clash_dis=4.0,
                  parallel=False,
                  approx=0,
                  verbosity=0,
                  merge_bblock=None,
                  pbar=False,
                  pbar_interval=10.0,
                  **kw):
    """Filter *rslt* down to entries with no CA-CA clashes.

    Checks up to *max_clash_check* results (all if negative) and returns a
    ResultJIT containing only the surviving entries.

    Fixes vs. original: the status print formatted merge_bblock with ':04'
    unconditionally and raised TypeError when merge_bblock was None (its
    default); ``ok`` is sized from max_clash_check with zeros instead of
    ``np.empty(len(futures))`` (which depended on tqdm's __len__ after
    reassignment); loop-invariant dirns/iress are built once.
    """
    # print('todo: clash check should handle symmetry')
    if max_clash_check == 0:
        return rslt
    max_clash_check = min(max_clash_check, len(rslt.idx))
    if max_clash_check < 0:
        max_clash_check = len(rslt.idx)
    if not pbar:
        print(
            f"mbb{f'{merge_bblock:04}' if merge_bblock is not None else 'none'} checking clashes",
            max_clash_check, 'of', len(rslt.err))
    verts = tuple(ssdag.verts)
    # loop-invariant per-vertex data, hoisted out of the result loop
    dirns = tuple([v.dirn for v in verts])
    iress = tuple([v.ires for v in verts])
    # exe = cf.ProcessPoolExecutor if parallel else InProcessExecutor
    exe = InProcessExecutor
    with exe() as pool:
        futures = list()
        for i in range(max_clash_check):
            chains = tuple([
                ssdag.bbs[k][verts[k].ibblock[rslt.idx[i, k]]].chains
                for k in range(len(ssdag.verts))
            ])
            ncacs = tuple([
                ssdag.bbs[k][verts[k].ibblock[rslt.idx[i, k]]].ncac
                for k in range(len(ssdag.verts))
            ])
            futures.append(
                pool.submit(_check_all_chain_clashes,
                            dirns=dirns,
                            iress=iress,
                            idx=rslt.idx[i],
                            pos=rslt.pos[i],
                            chn=chains,
                            ncacs=ncacs,
                            thresh=ca_clash_dis * ca_clash_dis,
                            approx=approx))
            futures[-1].index = i
        if pbar:
            desc = 'checking clashes '
            if merge_bblock is not None and merge_bblock >= 0:
                desc = f'{desc} mbb{merge_bblock:04d}'
            if merge_bblock is None:
                merge_bblock = 0
            futures = tqdm(
                cf.as_completed(futures),
                desc=desc,
                total=len(futures),
                mininterval=pbar_interval,
                position=merge_bblock + 1,
            )
        ok = np.zeros(max_clash_check, dtype='?')
        for f in futures:
            ok[f.index] = f.result()
    return ResultJIT(rslt.pos[:max_clash_check][ok],
                     rslt.idx[:max_clash_check][ok],
                     rslt.err[:max_clash_check][ok], rslt.stats)
def merge_results_concat(criteria, ssdag, ssdagA, rsltA, critB, ssdagB, rsltB,
                         merged_err_cut, max_merge, **kw):
    """Concatenate B-side partial results with hashed A-side partials.

    For each result in *rsltB* (capped at *max_merge*), looks up its final
    position in critB's hash table to find the matching A-side result index,
    splices the two index paths at criteria.from_seg, rescores, and keeps
    entries with err <= *merged_err_cut*, sorted ascending by err.

    Fixes vs. original: ``dtype=np.bool`` — that alias was removed in
    NumPy 1.24 (AttributeError) — is now plain ``bool``; the bad-entry count
    uses ``~ok`` instead of integer subtraction; unused isite_* locals were
    dropped.
    """
    # bblock-spec letters must line up: full = B prefix ... A suffix
    bsfull = [x[0] for x in ssdag.bbspec]
    bspartA = [x[0] for x in ssdagA.bbspec]
    bspartB = [x[0] for x in ssdagB.bbspec]
    assert bsfull[-len(bspartA):] == bspartA
    assert bsfull[:len(bspartB)] == bspartB
    # print('merge_results_concat ssdag.bbspec', ssdag.bbspec)
    # print('merge_results_concat criteria.bbspec', criteria.bbspec)
    rsltB = subset_result(rsltB, slice(max_merge))
    binner = critB.binner
    hash_table = critB.hash_table
    from_seg = criteria.from_seg
    # the merge segment must hold exactly one shared bblock in all graphs
    assert len(ssdagB.bbs[-1]) == len(ssdagA.bbs[0])
    assert len(ssdagB.bbs[-1]) == len(ssdag.bbs[from_seg])
    assert len(ssdagB.bbs[-1]) == 1, "did you set merge_bblock?"
    assert ssdagB.bbs[-1][0].filehash == ssdagA.bbs[0][0].filehash
    assert ssdagB.bbs[-1][0].filehash == ssdag.bbs[from_seg][0].filehash
    # every segment's bblock files must match B before the merge point
    # and A after it
    for iseg in range(from_seg):
        f = [bb.filehash for bb in ssdag.bbs[iseg]]
        assert f == [bb.filehash for bb in ssdagB.bbs[iseg]]
    for iseg in range(len(ssdag.verts) - from_seg):
        f = [bb.filehash for bb in ssdag.bbs[from_seg + iseg]]
        assert f == [bb.filehash for bb in ssdagA.bbs[iseg]]
    n = len(rsltB.idx)
    nv = len(ssdag.verts)
    merged = ResultJIT(
        pos=np.empty((n, nv, 4, 4), dtype="f4"),
        idx=np.empty((n, nv), dtype="i4"),
        err=9e9 * np.ones((n, ), dtype="f8"),
        stats=np.empty(n, dtype="i4"),
    )
    ok = np.ones(n, dtype=bool)
    for i_in_rslt in range(n):
        # recover the A-side result index packed into the hash value
        val = _get_hash_val(binner, hash_table, rsltB.pos[i_in_rslt, -1],
                            criteria.nfold)
        if val < 0:
            print("val < 0")
            ok[i_in_rslt] = False
            continue
        i_ot_rslt = np.right_shift(val, 32)
        assert i_ot_rslt < len(rsltA.idx)
        # check score asap: B's poses, then A's poses re-based at B's end
        pos = np.concatenate((
            rsltB.pos[i_in_rslt, :-1],
            rsltB.pos[i_in_rslt, -1] @ rsltA.pos[i_ot_rslt, :],
        ))
        assert np.allclose(pos[from_seg], rsltB.pos[i_in_rslt, -1])
        err = criteria.score(pos.reshape(-1, 1, 4, 4))
        merged.err[i_in_rslt] = err
        if err > merged_err_cut:
            continue
        i_outer = rsltA.idx[i_ot_rslt, 0]
        i_inner = rsltB.idx[i_in_rslt, -1]
        v_inner = ssdagB.verts[-1]
        v_outer = ssdagA.verts[0]
        ibb = v_outer.ibblock[i_outer]
        assert ibb == 0
        ires_in = v_inner.ires[i_inner, 0]
        ires_out = v_outer.ires[i_outer, 1]
        mrgv = ssdag.verts[from_seg]
        assert max(mrgv.ibblock) == 0
        assert max(ssdagA.verts[-1].ibblock) == 0
        # find the merge-vertex entry with this (in-res, out-res) pair
        imerge = util.binary_search_pair(mrgv.ires, (ires_in, ires_out))
        if imerge == -1:
            ok[i_in_rslt] = False
            continue
        idx = np.concatenate(
            (rsltB.idx[i_in_rslt, :-1], [imerge], rsltA.idx[i_ot_rslt, 1:]))
        assert len(idx) == len(ssdag.verts)
        for ii, v in zip(idx, ssdag.verts):
            if v is not None:
                assert ii < v.len
        assert len(pos) == len(idx) == nv
        merged.pos[i_in_rslt] = pos
        merged.idx[i_in_rslt] = idx
        merged.stats[i_in_rslt] = i_ot_rslt
    nbad = np.sum(~ok)
    if nbad:
        print("bad imerge", nbad, "of", n)
    # print('bad score', np.sum(merged.err > merged_err_cut), 'of', n)
    ok[merged.err > merged_err_cut] = False
    # indices of survivors, sorted by ascending error
    ok = np.where(ok)[0][np.argsort(merged.err[ok])]
    merged = subset_result(merged, ok)
    return merged