Ejemplo n.º 1
0
def _grow_linear_start(bb_base, verts_pickleable, edges_pickleable, **kwargs):
    """Worker entry point for the exhaustive linear-growth search.

    Rebuilds `_Vertex`/`_Edge` objects from their pickleable state tuples,
    allocates fixed-size output buffers, runs `_grow_linear_recurse`, and
    returns a `ResultJIT` trimmed to the number of results actually found.
    """
    verts = tuple(_Vertex(*state) for state in verts_pickleable)
    edges = tuple(_Edge(*state) for state in edges_pickleable)
    nverts = len(verts)
    # Pre-allocated output buffers; 1024 is the per-worker result capacity.
    buffers = ResultJIT(
        pos=np.empty((1024, nverts, 4, 4), dtype=np.float32),
        idx=np.empty((1024, nverts), dtype=np.int32),
        err=np.empty((1024,), dtype=np.float32),
        stats=zero_search_stats(),
    )
    base_hashes = np.zeros(nverts, dtype=np.int64)
    nresults, filled = _grow_linear_recurse(result=buffers,
                                            bb_base=bb_base,
                                            verts=verts,
                                            edges=edges,
                                            bases=base_hashes,
                                            **kwargs)
    # Trim the over-allocated buffers down to the entries actually written.
    return ResultJIT(filled.pos[:nresults], filled.idx[:nresults],
                     filled.err[:nresults], filled.stats)
Ejemplo n.º 2
0
def grow_linear(ssdag,
                loss_function=null_lossfunc,
                loss_threshold=2.0,
                last_bb_same_as=-1,
                parallel=0,
                monte_carlo=0,
                verbosity=0,
                merge_bblock=None,
                lbl='',
                pbar=False,
                pbar_interval=10.0,
                no_duplicate_bases=True,
                max_linear=1000000,
                **kw):
    """Search ssdag for linear paths through its vertices, scored by loss_function.

    Dispatches either an exhaustive search (``_grow_linear_start``, batched over
    the first vertex's entries) or a time-limited Monte-Carlo search
    (``_grow_linear_mc_start``, one job per CPU) across a thread pool, then
    concatenates, deduplicates, and sorts the per-job results by error.

    Args:
        ssdag: search DAG providing ``verts``, ``edges``, and ``bbs``.
        loss_function: callable scoring a candidate path.
        loss_threshold: maximum acceptable loss (passed through to workers).
        last_bb_same_as: vertex index constraint; None is normalized to -1.
        parallel: worker-thread count; 0 runs everything in-process.
        monte_carlo: if nonzero, run MC sampling for this many seconds
            instead of the exhaustive search.
        verbosity, merge_bblock, lbl, pbar, pbar_interval: logging/progress
            options forwarded to workers and tqdm.
        no_duplicate_bases: if True, pass real basehash values so workers can
            reject paths reusing a base; if False, all hashes are 0.
        max_linear: cap on the number of results a worker may collect.
        **kw: ignored extras.

    Returns:
        ResultJIT with pos/idx/err sorted ascending by err, duplicates removed.
    """
    verts = ssdag.verts
    edges = ssdag.edges
    if last_bb_same_as is None: last_bb_same_as = -1
    # Structural sanity: a linear path needs >= 2 vertices, one edge between
    # each adjacent pair, open (dirn==2) termini, and complementary
    # in/out directions (0/1) at every interior junction.
    assert len(verts) > 1
    assert len(verts) == len(edges) + 1
    assert verts[0].dirn[0] == 2
    assert verts[-1].dirn[1] == 2
    for ivertex in range(len(verts) - 1):
        assert verts[ivertex].dirn[1] + verts[ivertex + 1].dirn[0] == 1

    # if isinstance(loss_function, types.FunctionType):
    #     if not 'NUMBA_DISABLE_JIT' in os.environ:
    #         loss_function = nb.njit(nogil=1, fastmath=1)

    exe = cf.ThreadPoolExecutor(
        max_workers=parallel) if parallel else InProcessExecutor()
    # exe = cf.ProcessPoolExecutor(max_workers=parallel) if parallel else InProcessExecutor()
    with exe as pool:
        # Per-bblock base hashes; zeros disable duplicate-base rejection.
        bb_base = tuple([
            np.array([b.basehash if no_duplicate_bases else 0 for b in bb],
                     dtype=np.int64) for bb in ssdag.bbs
        ])
        # Vertices/edges are shipped to workers as pickleable state tuples
        # and reconstructed on the worker side.
        verts_pickleable = [v._state for v in verts]
        edges_pickleable = [e._state for e in edges]
        kwargs = dict(
            bb_base=bb_base,
            verts_pickleable=verts_pickleable,
            edges_pickleable=edges_pickleable,
            loss_function=loss_function,
            loss_threshold=loss_threshold,
            last_bb_same_as=last_bb_same_as,
            nresults=0,
            isplice=0,
            splice_position=np.eye(4, dtype=vertex_xform_dtype),
            max_linear=max_linear,
        )
        futures = list()
        if monte_carlo:
            # MC mode: one job per worker thread, each running for
            # `monte_carlo` seconds over the full first-vertex range.
            kwargs['fn'] = _grow_linear_mc_start
            kwargs['seconds'] = monte_carlo
            kwargs['ivertex_range'] = (0, verts[0].len)
            kwargs['merge_bblock'] = merge_bblock
            kwargs['lbl'] = lbl
            kwargs['verbosity'] = verbosity
            kwargs['pbar'] = pbar
            kwargs['pbar_interval'] = pbar_interval
            njob = cpu_count() if parallel else 1
            for ivert in range(njob):
                kwargs['threadno'] = ivert
                futures.append(pool.submit(**kwargs))
        else:
            # Exhaustive mode: split the first vertex's entries into batches
            # sized so each core gets roughly 64 batches.
            kwargs['fn'] = _grow_linear_start
            nbatch = max(1, int(verts[0].len / 64 / cpu_count()))
            for ivert in range(0, verts[0].len, nbatch):
                ivert_end = min(verts[0].len, ivert + nbatch)
                kwargs['ivertex_range'] = ivert, ivert_end
                futures.append(pool.submit(**kwargs))
        results = list()
        if monte_carlo:
            for f in cf.as_completed(futures):
                results.append(f.result())
        else:
            desc = 'linear search ' + str(lbl)
            if merge_bblock is None: merge_bblock = 0
            fiter = cf.as_completed(futures)
            if pbar:
                fiter = tqdm(fiter,
                             desc=desc,
                             position=merge_bblock + 1,
                             mininterval=pbar_interval,
                             total=len(futures))
            for f in fiter:
                results.append(f.result())
    # Sum per-worker search statistics element-wise into one stats record.
    tot_stats = zero_search_stats()
    for i in range(len(tot_stats)):
        tot_stats[i][0] += sum([r.stats[i][0] for r in results])
    result = ResultJIT(pos=np.concatenate([r.pos for r in results]),
                       idx=np.concatenate([r.idx for r in results]),
                       err=np.concatenate([r.err for r in results]),
                       stats=tot_stats)
    result = remove_duplicate_results(result)
    # Return results ordered best (lowest error) first.
    order = np.argsort(result.err)
    return ResultJIT(pos=result.pos[order],
                     idx=result.idx[order],
                     err=result.err[order],
                     stats=result.stats)
Ejemplo n.º 3
0
def _grow_linear_mc_start(seconds, verts_pickleable, edges_pickleable,
                          threadno, pbar, lbl, verbosity, merge_bblock,
                          pbar_interval, **kwargs):
    """Worker entry point for the Monte-Carlo linear search.

    Rebuilds verts/edges from pickleable state, then repeatedly calls
    `_grow_linear_mc` in batches until `seconds` have elapsed or
    `kwargs['max_linear']` results are collected, periodically pruning
    duplicates. Only thread 0 drives the tqdm progress bar (when `pbar`).

    Returns a ResultJIT trimmed to the results found.
    """
    tstart = time()
    verts = tuple([_Vertex(*vp) for vp in verts_pickleable])
    edges = tuple([_Edge(*ep) for ep in edges_pickleable])
    # Fixed-capacity result buffers (1024 entries per worker).
    pos = np.empty(shape=(1024, len(verts), 4, 4), dtype=np.float32)
    idx = np.empty(shape=(1024, len(verts)), dtype=np.int32)
    err = np.empty(shape=(1024, ), dtype=np.float32)
    stats = zero_search_stats()
    result = ResultJIT(pos=pos, idx=idx, err=err, stats=stats)
    bases = np.zeros(len(verts), dtype=np.int64)
    # nresults is tracked locally here, not passed through kwargs.
    del kwargs['nresults']

    if threadno == 0 and pbar:
        desc = 'linear search ' + str(lbl)
        if merge_bblock is None: merge_bblock = 0
        # Progress is measured in elapsed seconds, not result count.
        pbar_inst = tqdm(desc=desc,
                         position=merge_bblock + 1,
                         total=seconds,
                         mininterval=pbar_interval)
        last = tstart

    # Larger MC batches for shorter paths (fewer edges); capped at 1 batch
    # unit for very long paths, then scaled by 10.
    nbatch = [1000, 330, 100, 33, 10, 3] + [1] * 99
    nbatch = nbatch[len(edges)] * 10
    nresults = 0
    iter = 0
    ndups = 0
    while time() < tstart + seconds:
        # 'pbar_inst' in vars() is only true on thread 0 with pbar enabled.
        if 'pbar_inst' in vars():
            pbar_inst.update(time() - last)
            last = time()
        nresults, result = _grow_linear_mc(nbatch,
                                           result,
                                           verts,
                                           edges,
                                           bases=bases,
                                           nresults=nresults,
                                           **kwargs)

        iter += 1
        # remove duplicates every 10th iter
        if iter % 10 == 0:
            nresults_with_dups = nresults
            uniq_result = ResultJIT(idx=result.idx[:nresults],
                                    pos=result.pos[:nresults],
                                    err=result.err[:nresults],
                                    stats=result.stats)
            uniq_result = remove_duplicate_results(uniq_result)
            nresults = len(uniq_result.err)
            # Compact the unique results back into the front of the buffers.
            result.idx[:nresults] = uniq_result.idx
            result.pos[:nresults] = uniq_result.pos
            result.err[:nresults] = uniq_result.err
            ndups += nresults_with_dups - nresults
            # print(ndups / nresults)

        if nresults >= kwargs['max_linear']: break

    if 'pbar_inst' in vars(): pbar_inst.close()

    # Trim over-allocated buffers to the entries actually filled.
    result = ResultJIT(result.pos[:nresults], result.idx[:nresults],
                       result.err[:nresults], result.stats)
    return result
Ejemplo n.º 4
0
def merge_results_bridge_long(criteria, critC, ssdag, ssdB, ssdC, rsltC, **kw):
    """Merge C-segment results with B-segment results via critC's hash table.

    For each result in rsltC, the relative transform between to_seg and
    from_seg is binned and looked up in critC.filter_hash to recover the
    matching B-side vertex indices. Candidates that merge onto a different
    bblock, reuse the same splice site, or score worse than kw['tolerance']
    are discarded.

    Returns:
        ResultJIT of the merged results, or None if no candidate survived.

    Raises:
        AssertionError: if any rsltC transform is missing from the hash table.
    """
    # look up rsltCs in critC hashtable to get Bs

    sizesB = np.array([len(v.ibblock) for v in ssdB.verts])

    idx_list = list()
    pos_list = list()
    err_list = list()
    missing = 0
    for iresult in range(len(rsltC.err)):
        idxC = rsltC.idx[iresult]
        posC = rsltC.pos[iresult]
        # Relative transform from from_seg frame to to_seg frame.
        xhat = posC[criteria.to_seg] @ np.linalg.inv(posC[criteria.from_seg])
        xhat = xhat.astype(np.float64)
        key = critC.filter_binner(xhat)
        val = critC.filter_hash.get(key)
        assert val < np.prod(sizesB)
        # -123_456_789 is the hash table's "not found" sentinel.
        if val == -123_456_789:
            missing += 1
            continue
        idxB = decode_indices(sizesB, val)
        # Both halves must merge onto the same building block...
        merge_ibblock = ssdB.verts[0].ibblock[idxB[0]]
        merge_ibblock_c = ssdC.verts[-1].ibblock[idxC[-1]]
        if merge_ibblock != merge_ibblock_c:
            continue
        # ...but through two distinct splice sites.
        merge_site1 = ssdB.verts[0].isite[idxB[0], 1]
        merge_site2 = ssdC.verts[-1].isite[idxC[-1], 0]
        if merge_site1 == merge_site2:
            continue

        merge_outres = ssdB.verts[0].ires[idxB[0], 1]
        merge_inres = ssdC.verts[-1].ires[idxC[-1], 0]

        # Find the entry in the merged-dag vertex matching bblock + both
        # splice residues.
        imergeseg = len(idxC) - 1
        vmerge = ssdag.verts[imergeseg]
        w = ((vmerge.ibblock == merge_ibblock) *
             (vmerge.ires[:, 0] == merge_inres) *
             (vmerge.ires[:, 1] == merge_outres))
        imerge = np.where(w)[0]
        # BUG FIX: was `len(imerge) is 0` — identity comparison with an int
        # literal is implementation-dependent; use equality.
        if len(imerge) == 0:
            print("    empty imerge")
            continue
        if len(imerge) > 1:
            print("    imerge", imerge)
            assert len(imerge) == 1
        imerge = imerge[0]
        # Full merged index path: C prefix + merge vertex + B suffix.
        idx = np.concatenate([idxC[:-1], [imerge], idxB[1:]])

        # Recompute positions by walking the merged dag, then score.
        spos = np.eye(4)
        pos = list()
        for i, v in enumerate(ssdag.verts):
            index = idx[i]
            pos.append(spos @ v.x2orig[index])
            spos = spos @ v.x2exit[index]
        pos = np.stack(pos)
        err = criteria.score(pos)
        if err > kw["tolerance"]:
            continue

        idx_list.append(idx)
        pos_list.append(pos)
        err_list.append(err)

    if missing > 0:
        print(
            "merge_results_bridge_long: xform missing from hash_table:",
            missing,
            "of",
            len(rsltC.err),
        )
    assert missing == 0

    if len(pos_list) > 0:
        return ResultJIT(
            np.stack(pos_list),
            np.stack(idx_list),
            np.array(err_list),
            SearchStats(0, 0, 0),
        )
    return None
Ejemplo n.º 5
0
def prune_clashes(
      ssdag,
      crit,
      rslt,
      max_clash_check=-1,
      ca_clash_dis=4.0,
      parallel=False,
      approx=0,
      verbosity=0,
      merge_bblock=None,
      pbar=False,
      pbar_interval=10.0,
      context_structure=None,
      **kw,
):
   """Filter search results by backbone clash checking.

   Checks up to max_clash_check results (all of them when negative) against
   inter-chain CA-CA clashes via _check_all_chain_clashes, and optionally
   against a ClashGrid context structure. Returns a ResultJIT containing
   only the clash-free subset.

   Args:
       ssdag: search dag providing verts and bbs.
       crit: unused here; kept for interface compatibility.
       rslt: ResultJIT to filter.
       max_clash_check: cap on results checked; 0 returns rslt unchanged.
       ca_clash_dis: CA-CA distance threshold in Angstroms (squared before use).
       parallel: currently ignored — execution is always in-process.
       context_structure: optional ClashGrid to clash-check against.
   """
   # print('todo: clash check should handle symmetry')
   if max_clash_check == 0:
      return rslt
   max_clash_check = min(max_clash_check, len(rslt.idx))
   if max_clash_check < 0:
      max_clash_check = len(rslt.idx)

   if not pbar:
      # BUG FIX: test `is not None`, not truthiness — merge_bblock == 0 is a
      # legitimate value and previously printed 'none'.
      print(
         f"mbb{f'{merge_bblock:04}' if merge_bblock is not None else 'none'} checking clashes",
         max_clash_check,
         "of",
         len(rslt.err),
      )

   verts = tuple(ssdag.verts)
   # exe = cf.ProcessPoolExecutor if parallel else InProcessExecutor
   exe = InProcessExecutor
   with exe() as pool:
      futures = list()
      for i in range(max_clash_check):
         dirns = tuple([v.dirn for v in verts])
         iress = tuple([v.ires for v in verts])
         # Per-segment chain boundaries and N-CA-C coordinates for the
         # specific building block chosen by this result.
         chains = tuple([
            ssdag.bbs[k][verts[k].ibblock[rslt.idx[i, k]]].chains for k in range(len(ssdag.verts))
         ])
         ncacs = tuple([
            ssdag.bbs[k][verts[k].ibblock[rslt.idx[i, k]]].ncac for k in range(len(ssdag.verts))
         ])
         if isinstance(context_structure, ClashGrid):
            # Reject early if any segment clashes with the fixed context.
            clash = False
            for pos, ncac in zip(rslt.pos[i], ncacs):
               xyz = pos @ ncac[..., None]
               if context_structure.clashcheck(xyz.squeeze()):
                  clash = True
                  break
            if clash:
               continue

         futures.append(
            pool.submit(
               _check_all_chain_clashes,
               dirns=dirns,
               iress=iress,
               idx=rslt.idx[i],
               pos=rslt.pos[i],
               chn=chains,
               ncacs=ncacs,
               thresh=ca_clash_dis * ca_clash_dis,
               approx=approx,
            ))
         # Remember which result this future corresponds to.
         futures[-1].index = i

      if pbar:
         desc = "checking clashes "
         if merge_bblock is not None and merge_bblock >= 0:
            desc = f"{desc}    mbb{merge_bblock:04d}"
         if merge_bblock is None:
            merge_bblock = 0
         futures = tqdm(
            cf.as_completed(futures),
            desc=desc,
            total=len(futures),
            mininterval=pbar_interval,
            position=merge_bblock + 1,
         )

      # zeros (not empty): indices skipped by the context-clash early-out
      # must read as "not ok".
      ok = np.zeros(max_clash_check, dtype="?")
      for f in futures:
         ok[f.index] = f.result()

   return ResultJIT(
      rslt.pos[:max_clash_check][ok],
      rslt.idx[:max_clash_check][ok],
      rslt.err[:max_clash_check][ok],
      rslt.stats,
   )
Ejemplo n.º 6
0
def prune_clashes(ssdag,
                  crit,
                  rslt,
                  max_clash_check=-1,
                  ca_clash_dis=4.0,
                  parallel=False,
                  approx=0,
                  verbosity=0,
                  merge_bblock=None,
                  pbar=False,
                  pbar_interval=10.0,
                  **kw):
    """Filter search results by backbone clash checking.

    Checks up to max_clash_check results (all of them when negative) against
    inter-chain CA-CA clashes via _check_all_chain_clashes and returns a
    ResultJIT containing only the clash-free subset.

    Args:
        ssdag: search dag providing verts and bbs.
        crit: unused here; kept for interface compatibility.
        rslt: ResultJIT to filter.
        max_clash_check: cap on results checked; 0 returns rslt unchanged.
        ca_clash_dis: CA-CA distance threshold in Angstroms (squared before use).
        parallel: currently ignored — execution is always in-process.
    """
    # print('todo: clash check should handle symmetry')
    if max_clash_check == 0:
        return rslt
    max_clash_check = min(max_clash_check, len(rslt.idx))
    if max_clash_check < 0: max_clash_check = len(rslt.idx)

    if not pbar:
        # BUG FIX: merge_bblock defaults to None, and formatting None with
        # :04 raised TypeError; print 'none' in that case instead.
        print(
            f"mbb{f'{merge_bblock:04}' if merge_bblock is not None else 'none'} checking clashes",
            max_clash_check, 'of', len(rslt.err))

    verts = tuple(ssdag.verts)
    # exe = cf.ProcessPoolExecutor if parallel else InProcessExecutor
    exe = InProcessExecutor
    with exe() as pool:
        futures = list()
        for i in range(max_clash_check):
            dirns = tuple([v.dirn for v in verts])
            iress = tuple([v.ires for v in verts])
            # Per-segment chain boundaries and N-CA-C coordinates for the
            # specific building block chosen by this result.
            chains = tuple([
                ssdag.bbs[k][verts[k].ibblock[rslt.idx[i, k]]].chains
                for k in range(len(ssdag.verts))
            ])
            ncacs = tuple([
                ssdag.bbs[k][verts[k].ibblock[rslt.idx[i, k]]].ncac
                for k in range(len(ssdag.verts))
            ])
            futures.append(
                pool.submit(_check_all_chain_clashes,
                            dirns=dirns,
                            iress=iress,
                            idx=rslt.idx[i],
                            pos=rslt.pos[i],
                            chn=chains,
                            ncacs=ncacs,
                            thresh=ca_clash_dis * ca_clash_dis,
                            approx=approx))
            # Remember which result this future corresponds to.
            futures[-1].index = i

        if pbar:
            desc = 'checking clashes '
            if merge_bblock is not None and merge_bblock >= 0:
                desc = f'{desc}    mbb{merge_bblock:04d}'
            if merge_bblock is None:
                merge_bblock = 0
            futures = tqdm(
                cf.as_completed(futures),
                desc=desc,
                total=len(futures),
                mininterval=pbar_interval,
                position=merge_bblock + 1,
            )

        # One future per checked result, so every slot gets filled below.
        ok = np.empty(max_clash_check, dtype='?')
        for f in futures:
            ok[f.index] = f.result()

    return ResultJIT(rslt.pos[:max_clash_check][ok],
                     rslt.idx[:max_clash_check][ok],
                     rslt.err[:max_clash_check][ok], rslt.stats)
Ejemplo n.º 7
0
def merge_results_concat(criteria, ssdag, ssdagA, rsltA, critB, ssdagB, rsltB,
                         merged_err_cut, max_merge, **kw):
    """Concatenate B-side results with hash-matched A-side results.

    For each result in rsltB (capped at max_merge), the final position is
    looked up in critB's hash table to recover the matching rsltA index
    (packed in the upper 32 bits of the hash value). Position stacks and
    index paths are concatenated across the merge vertex, re-scored with
    criteria.score, filtered by merged_err_cut, and returned sorted by error.

    Returns:
        ResultJIT of merged results; stats holds the source rsltA index.
    """
    bsfull = [x[0] for x in ssdag.bbspec]
    bspartA = [x[0] for x in ssdagA.bbspec]
    bspartB = [x[0] for x in ssdagB.bbspec]
    # The full dag's spec must be B's spec followed by A's spec (overlapping
    # at the merge segment).
    assert bsfull[-len(bspartA):] == bspartA
    assert bsfull[:len(bspartB)] == bspartB

    rsltB = subset_result(rsltB, slice(max_merge))

    binner = critB.binner
    hash_table = critB.hash_table
    from_seg = criteria.from_seg

    # Building blocks must line up across the three dags at the merge point.
    assert len(ssdagB.bbs[-1]) == len(ssdagA.bbs[0])
    assert len(ssdagB.bbs[-1]) == len(ssdag.bbs[from_seg])
    assert len(ssdagB.bbs[-1]) == 1, "did you set merge_bblock?"
    assert ssdagB.bbs[-1][0].filehash == ssdagA.bbs[0][0].filehash
    assert ssdagB.bbs[-1][0].filehash == ssdag.bbs[from_seg][0].filehash
    for _ in range(from_seg):
        f = [bb.filehash for bb in ssdag.bbs[_]]
        assert f == [bb.filehash for bb in ssdagB.bbs[_]]
    for _ in range(len(ssdag.verts) - from_seg):
        f = [bb.filehash for bb in ssdag.bbs[from_seg + _]]
        assert f == [bb.filehash for bb in ssdagA.bbs[_]]

    n = len(rsltB.idx)
    nv = len(ssdag.verts)
    merged = ResultJIT(
        pos=np.empty((n, nv, 4, 4), dtype="f4"),
        idx=np.empty((n, nv), dtype="i4"),
        err=9e9 * np.ones((n, ), dtype="f8"),
        stats=np.empty(n, dtype="i4"),
    )
    # BUG FIX: np.bool was removed in NumPy 1.24; use the builtin bool.
    ok = np.ones(n, dtype=bool)
    for i_in_rslt in range(n):
        val = _get_hash_val(binner, hash_table, rsltB.pos[i_in_rslt, -1],
                            criteria.nfold)
        if val < 0:
            print("val < 0")
            ok[i_in_rslt] = False
            continue
        # Upper 32 bits of the hash value index into rsltA.
        i_ot_rslt = np.right_shift(val, 32)
        assert i_ot_rslt < len(rsltA.idx)

        # check score asap
        pos = np.concatenate((
            rsltB.pos[i_in_rslt, :-1],
            rsltB.pos[i_in_rslt, -1] @ rsltA.pos[i_ot_rslt, :],
        ))
        assert np.allclose(pos[from_seg], rsltB.pos[i_in_rslt, -1])
        err = criteria.score(pos.reshape(-1, 1, 4, 4))
        merged.err[i_in_rslt] = err
        if err > merged_err_cut:
            continue

        # Locate the merged-dag vertex entry matching the splice residues
        # on both sides of the merge point.
        i_outer = rsltA.idx[i_ot_rslt, 0]
        i_outer2 = rsltA.idx[i_ot_rslt, -1]
        i_inner = rsltB.idx[i_in_rslt, -1]
        v_inner = ssdagB.verts[-1]
        v_outer = ssdagA.verts[0]
        ibb = v_outer.ibblock[i_outer]
        assert ibb == 0
        ires_in = v_inner.ires[i_inner, 0]
        ires_out = v_outer.ires[i_outer, 1]
        isite_in = v_inner.isite[i_inner, 0]
        isite_out = v_outer.isite[i_outer, 1]
        isite_out2 = ssdagA.verts[-1].isite[i_outer2, 0]
        mrgv = ssdag.verts[from_seg]
        assert max(mrgv.ibblock) == 0
        assert max(ssdagA.verts[-1].ibblock) == 0

        imerge = util.binary_search_pair(mrgv.ires, (ires_in, ires_out))
        if imerge == -1:
            ok[i_in_rslt] = False
            continue
        idx = np.concatenate(
            (rsltB.idx[i_in_rslt, :-1], [imerge], rsltA.idx[i_ot_rslt, 1:]))
        assert len(idx) == len(ssdag.verts)
        for ii, v in zip(idx, ssdag.verts):
            if v is not None:
                assert ii < v.len
        assert len(pos) == len(idx) == nv
        merged.pos[i_in_rslt] = pos
        merged.idx[i_in_rslt] = idx
        merged.stats[i_in_rslt] = i_ot_rslt
    nbad = np.sum(1 - ok)
    if nbad:
        print("bad imerge", nbad, "of", n)
    # Drop over-threshold entries, then order survivors by ascending error.
    ok[merged.err > merged_err_cut] = False
    ok = np.where(ok)[0][np.argsort(merged.err[ok])]
    merged = subset_result(merged, ok)
    return merged