def refine(rec: Rec, diagsearch, pedantic=False):
    """Split ``rec`` into smaller rectangles around the threshold boundary.

    Args:
        rec: Rectangle to refine.
        diagsearch: Callable ``rec -> (SearchResultType, Rec)`` that searches
            along ``rec``'s diagonal (for example ``binsearch`` with a bound
            oracle).
        pedantic: If True, raise when the diagonal search reports a trivial
            (non-crossing) result instead of collapsing ``rec`` to a corner.

    Returns:
        A list of rectangles:
        - ``[rec]`` when ``rec`` is already a point;
        - a single corner point when the search is trivial and ``pedantic``
          is False;
        - otherwise the pieces produced by ``rec.subdivide``.

    Raises:
        RuntimeError: ``pedantic`` is True and the search result is not
            ``SearchResultType.NON_TRIVIAL``.
    """
    if rec.is_point:
        return [rec]

    if rec.degenerate:
        # No diagonal search here; split at the midpoint of every interval.
        # (``degenerate`` presumably means zero width along some axis — confirm.)
        drop_fb = False
        rec2 = to_rec((_midpoint(i) for i in rec.intervals), error=0)
    else:
        drop_fb = True
        result_type, rec2 = diagsearch(rec)
        if pedantic and result_type != SearchResultType.NON_TRIVIAL:
            raise RuntimeError(f"Threshold function does not intersect {rec}.")
        # ``binsearch`` reports TRIVIALLY_TRUE when the oracle already holds at
        # the bottom corner, and TRIVIALLY_FALSE when it still fails at the top
        # corner; its own result bracket collapses onto that same corner. The
        # corner nearest the threshold is therefore bot for TRUE and top for
        # FALSE (assumes an oracle that is false below / true above).
        if result_type == SearchResultType.TRIVIALLY_TRUE:
            return [to_rec(zip(rec.bot, rec.bot))]
        if result_type == SearchResultType.TRIVIALLY_FALSE:
            return [to_rec(zip(rec.top, rec.top))]

    return list(rec.subdivide(rec2, drop_fb=drop_fb))
def _hausdorff_approxes(r1: Rec, r2: Rec, f1, f2, *, metric=hausdorff_bounds):
    """Yield successively refined distance estimates between two thresholds.

    Each threshold starts from the ``bounding_box`` of its rectangle. Every
    iteration does the following:
    - ``metric`` is applied to the current rectangle sets; it must return
      ``(d, (recs1, recs2))``.
    - Each surviving rectangle is refined with that threshold's refiner.
    - The bot and top corners of every refined rectangle are added as point
      rectangles.

    Args:
        r1, r2: Starting rectangles for the first and second threshold.
        f1, f2: Oracles for the first and second threshold.
        metric: Function mapping ``(recs1, recs2)`` to ``(d, (recs1, recs2))``.

    Yields:
        ``(d, (recs1, recs2))``: ``d`` is the value ``metric`` returned for
        this round, and ``recs1``/``recs2`` are the refined rectangle sets.
    """
    def _refine_all(refiner, recs):
        # ``set().union`` tolerates an empty ``recs``; ``set.union(*[])``
        # raises TypeError because the unbound method needs an argument.
        refined = set().union(*(refiner.send(r) for r in recs))
        # Also keep each rectangle's bot and top corners as point rectangles.
        corners = {to_rec(zip(c, c)) for r in refined for c in (r.bot, r.top)}
        return refined | corners

    recs1, recs2 = {bounding_box(r1, f1)}, {bounding_box(r2, f2)}
    refiner1, refiner2 = _refiner(f1), _refiner(f2)
    # Prime both generators so they can accept ``send``.
    next(refiner1), next(refiner2)
    while True:
        d, (recs1, recs2) = metric(recs1, recs2)
        recs1 = _refine_all(refiner1, recs1)
        recs2 = _refine_all(refiner2, recs2)
        yield d, (recs1, recs2)
def refine(rec: Rec, diagsearch, pedantic=True):
    """Split ``rec`` into smaller rectangles around the threshold boundary.

    Args:
        rec: Rectangle to refine.
        diagsearch: Callable ``rec -> (SearchResultType, Rec)`` that searches
            along ``rec``'s diagonal (for example ``binsearch`` with a bound
            oracle).
        pedantic: If True (the default, preserving this function's original
            behaviour), raise when the diagonal search reports a trivial
            result. If False, collapse ``rec`` to the corner nearest the
            threshold instead. Accepting this keyword keeps calls such as
            ``refine(rec, ds, pedantic=False)`` from failing with TypeError.

    Returns:
        A list of rectangles:
        - ``[rec]`` when ``rec`` is already a point;
        - a single corner point for a trivial search when ``pedantic`` is
          False;
        - otherwise the pieces produced by ``rec.subdivide``.

    Raises:
        RuntimeError: ``pedantic`` is True and the search result is not
            ``SearchResultType.NON_TRIVIAL``.
    """
    if rec.is_point:
        return [rec]

    if rec.degenerate:
        # Split at the midpoint of every interval. Use an exact (error=0)
        # point, as the pedantic-aware variant of this function does.
        drop_fb = False
        rec2 = to_rec((_midpoint(i) for i in rec.intervals), error=0)
    else:
        drop_fb = True
        result_type, rec2 = diagsearch(rec)
        if pedantic and result_type != SearchResultType.NON_TRIVIAL:
            raise RuntimeError(f"Threshold function does not intersect {rec}.")
        # TRIVIALLY_TRUE: the oracle already holds at the bottom corner.
        # TRIVIALLY_FALSE: it still fails at the top corner. Return the corner
        # nearest the threshold (assumes false below / true above).
        if result_type == SearchResultType.TRIVIALLY_TRUE:
            return [to_rec(zip(rec.bot, rec.bot))]
        if result_type == SearchResultType.TRIVIALLY_FALSE:
            return [to_rec(zip(rec.top, rec.top))]

    return list(rec.subdivide(rec2, drop_fb=drop_fb))
def bounding_box(r: Rec, oracle):
    """Compute a box anchored at ``r.bot`` whose upper corner bounds the threshold.

    Every edge of ``r`` (see ``box_edges``) moves along exactly one axis. A
    ``binsearch`` along each edge locates the oracle's crossing on that edge.
    The new upper corner takes, per axis, the largest crossing coordinate
    found on edges that move along that axis.

    Args:
        r: Rectangle to bound.
        oracle: Predicate forwarded to ``binsearch``.

    Returns:
        A rectangle with lower corner ``r.bot`` and the computed upper corner.
        Axes along which ``r`` has zero width keep ``r.top`` on that axis.
    """
    crossings = {}
    for edge in box_edges(r):
        moving = np.flatnonzero(np.array(edge.top) - np.array(edge.bot))
        if moving.size == 0:
            # Zero-length edge (r is flat along its axis): nothing to search.
            # Previously this produced an all-False key and crashed on
            # ``key.index(True)``.
            continue
        axis = int(moving[0])
        reached = binsearch(edge, oracle)[1].top[axis]
        if axis in crossings:
            crossings[axis] = max(crossings[axis], reached)
        else:
            crossings[axis] = reached

    # Build the corner in explicit axis order rather than relying on the
    # insertion order of a grouping dict.
    top = np.array([crossings.get(axis, hi) for axis, hi in enumerate(r.top)])
    return to_rec(intervals=tuple(zip(r.bot, top)))
def box_edges(r):
    """Yield the n * 2**(n-1) edges of the n-dimensional rectangle ``r``.

    Each edge is returned as a rectangle that:
    - runs from ``r.bot`` to ``r.top`` along exactly one axis;
    - is pinned at either ``r.bot`` or ``r.top`` on every other axis.

    Edges are yielded grouped by their moving axis, in axis order
    0, 1, ..., n-1. A 0-dimensional ``r`` has no edges.
    """
    n = len(r.bot)
    if n == 0:
        # ``product(..., repeat=-1)`` would raise ValueError; there are no edges.
        return

    bot = np.array(r.bot)
    diag = np.array(r.top) - bot
    # Choices of bot (0) / top (1) for the n-1 axes the edge does not move along.
    fixed_choices = [np.array(bits) for bits in product([1, 0], repeat=n - 1)]

    for axis in range(n):
        for bits in fixed_choices:
            # Insert the moving axis: 0 at the edge's start, 1 at its end.
            start = bot + np.insert(bits, axis, 0) * diag
            stop = bot + np.insert(bits, axis, 1) * diag
            yield to_rec(intervals=tuple(zip(start, stop)))
def binsearch(r: Rec, oracle, eps=1e-4) -> SearchResult:
    """Binary search over the diagonal of the rectangle.

    The diagonal is parametrised by ``f = diagonal_convex_comb(r)`` over t in
    [0, 1]. Presumably ``f(0)`` is ``r.bot`` and ``f(1)`` is ``r.top``; confirm
    against ``diagonal_convex_comb``. The bisection keeps the oracle false at
    ``f(lo)`` and true at ``f(hi)``, so it assumes an oracle that switches once
    from false to true along the diagonal.

    Args:
        r: Rectangle whose diagonal is searched.
        oracle: Predicate on points of ``r``.
        eps: Per-coordinate width at which the bracket stops shrinking.

    Returns:
        ``(result_type, rec)`` where ``rec`` spans ``f(lo)``..``f(hi)``:
        - TRIVIALLY_TRUE: oracle already true at t=0; ``rec`` is the point ``f(0)``.
        - TRIVIALLY_FALSE: oracle false at t=1; ``rec`` is the point ``f(1)``.
        - NON_TRIVIAL: ``rec`` brackets the switch, every side at most ``eps``.
    """
    f = diagonal_convex_comb(r)
    feval = fn.compose(oracle, f)
    lo, hi = 0, 1

    # Early termination via bounds checks. Return the endpoint itself so no
    # further oracle calls are spent and the result is an exact corner.
    if feval(lo):
        return SearchResultType.TRIVIALLY_TRUE, to_rec(zip(f(lo), f(lo)))
    if not feval(hi):
        return SearchResultType.TRIVIALLY_FALSE, to_rec(zip(f(hi), f(hi)))

    # Invariant: feval(lo) is False and feval(hi) is True.
    while (f(hi) - f(lo) > eps).any():
        mid = lo + (hi - lo) / 2
        if feval(mid):
            hi = mid
        else:
            lo = mid

    return SearchResultType.NON_TRIVIAL, to_rec(zip(f(lo), f(hi)))