def __init__(self, dom: Optional[AbsDom], bound_mins: Union[Tensor, List[float]],
             bound_maxs: Union[Tensor, List[float]], larger_category: int):
    """
    :param bound_mins: input lower bounds
    :param bound_maxs: input upper bounds
    :param larger_category: the output category that is supposed to be larger
    """
    if not isinstance(bound_mins, Tensor):
        bound_mins = torch.tensor(bound_mins)
    if not isinstance(bound_maxs, Tensor):
        bound_maxs = torch.tensor(bound_maxs)
    assert valid_lb_ub(bound_mins, bound_maxs)  # assert the check; a bare call discards the result

    self.bound_mins = bound_mins
    self.bound_maxs = bound_maxs
    self.larger_category = larger_category

    ''' Planet and Marabou try to find a counterexample with y0 <= 0, which means the safety
        property is y0 >= 0 (accepting y0 = 0 as well).
    '''
    super().__init__('CollisionProp', dom, safe_fn='cols_is_max', viol_fn='cols_not_max',
                     fn_args=(self.larger_category,))
    return
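# The sketch below is a hypothetical instantiation (not from the original code); the concrete
# bounds and the abstract domain `dom` are assumed purely for illustration.
def _demo_collision_prop(dom: AbsDom):
    """ A property asserting that output category 0 stays the largest over an assumed unit box. """
    return CollisionProp(dom, bound_mins=[0., 0.], bound_maxs=[1., 1.], larger_category=0)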
def _covered(new_lb: Tensor, new_ub: Tensor, new_label: Tensor) -> bool:
    """ Return True iff the new LB/UB is already covered by some intersected piece.
        Assuming new_lb/new_ub comes from X or Y, there is no partial intersection,
        so a subset check is sufficient. Assuming all params are not batched.
    """
    for i in range(len(shared_lbs)):
        shared_lb, shared_ub, shared_label = shared_lbs[i], shared_ubs[i], shared_labels[i]
        if valid_lb_ub(shared_lb, new_lb) and valid_lb_ub(new_ub, shared_ub):
            assert torch.equal(new_label | shared_label, shared_label), 'New intersected cube got more props?!'
            return True
    return False
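# A minimal sketch (hypothetical tensors) of the containment idiom _covered() relies on:
# box [inner_lb, inner_ub] lies inside box [outer_lb, outer_ub] iff outer_lb <= inner_lb and
# inner_ub <= outer_ub element-wise, i.e., two valid_lb_ub() checks.
def _demo_box_containment() -> bool:
    inner_lb, inner_ub = torch.tensor([0.2, 0.2]), torch.tensor([0.8, 0.8])
    outer_lb, outer_ub = torch.tensor([0., 0.]), torch.tensor([1., 1.])
    return valid_lb_ub(outer_lb, inner_lb) and valid_lb_ub(inner_ub, outer_ub)  # True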
def __init__(self, lb: Tensor, ub: Tensor):
    """ In the vanilla Interval domain, only the Lower Bounds and Upper Bounds are maintained. """
    assert valid_lb_ub(lb, ub)
    self._lb = lb
    self._ub = ub
    return
def gen_rnd_points(lb: Tensor, ub: Tensor, extra: Optional[Tensor], K: int) -> Tuple[Tensor, Optional[Tensor]]:
    """ Different from the old sample_points(), the output here maintains the batch structure.
        It also accepts extra.
    :param lb: lower bounds, batched
    :param ub: upper bounds, batched
    :param extra: e.g., bitmaps for properties in AndProp
    :param K: how many points to sample per abstraction
    :return: Batch x K x State, together with the expanded extra
    """
    assert valid_lb_ub(lb, ub)
    assert K >= 1

    new_size = list(lb.size())
    new_size.insert(1, K)  # Batch x States => Batch x K x States

    base = lb.unsqueeze(dim=1).expand(*new_size)
    width = (ub - lb).unsqueeze(dim=1).expand(*new_size)
    coefs = torch.rand_like(base)
    pts = base + coefs * width

    if extra is None:
        new_extra = None
    else:
        new_size = list(extra.size())
        new_size.insert(1, K)
        new_extra = extra.unsqueeze(dim=1).expand(*new_size)
    return pts, new_extra
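# A minimal usage sketch (shapes assumed, not from the original code): K = 3 points per box
# for a batch of two 2-D boxes keeps the Batch x K x State layout.
def _demo_gen_rnd_points() -> Tensor:
    lb = torch.tensor([[0., 0.], [-1., -1.]])
    ub = torch.tensor([[1., 1.], [0., 2.]])
    pts, _ = gen_rnd_points(lb, ub, extra=None, K=3)
    assert pts.shape == (2, 3, 2)  # Batch x K x State
    assert bool(((lb.unsqueeze(1) <= pts) & (pts <= ub.unsqueeze(1))).all())  # points stay in their boxes
    return pts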
def __init__(self, boxes_lb: Tensor, boxes_ub: Tensor, boxes_extra: Tensor = None):
    assert valid_lb_ub(boxes_lb, boxes_ub)
    self.boxes_lb = boxes_lb
    self.boxes_ub = boxes_ub
    self.boxes_extra = boxes_extra
    super().__init__(self.boxes_lb)
    return
def lbub_intersect(lb1: Tensor, ub1: Tensor, lb2: Tensor, ub2: Tensor) -> Tuple[Tensor, Tensor]:
    """ Return the intersection of [lb1, ub1] and [lb2, ub2], or raise ValueError when they do not overlap.
    :param lb1, ub1, lb2, ub2: not batched
    :return: not batched tensors
    """
    assert lb1.size() == lb2.size() and ub1.size() == ub2.size()

    res_lb, _ = torch.max(torch.stack((lb1, lb2), dim=-1), dim=-1)
    res_ub, _ = torch.min(torch.stack((ub1, ub2), dim=-1), dim=-1)

    if not valid_lb_ub(res_lb, res_ub):
        raise ValueError('Intersection failed.')
    return res_lb, res_ub
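# An illustrative sketch with assumed values: intersecting [0,2]x[0,2] with [1,3]x[-1,1]
# yields [1,2]x[0,1]; disjoint boxes raise ValueError instead.
def _demo_lbub_intersect() -> Tuple[Tensor, Tensor]:
    lb, ub = lbub_intersect(torch.tensor([0., 0.]), torch.tensor([2., 2.]),
                            torch.tensor([1., -1.]), torch.tensor([3., 1.]))
    assert torch.equal(lb, torch.tensor([1., 0.])) and torch.equal(ub, torch.tensor([2., 1.]))
    return lb, ub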
def total_area(lb: Tensor, ub: Tensor, eps: float = 1e-8, by_batch: bool = False) -> Union[float, Tensor]:
    """ Return the total area constrained by LB/UB: Area = \sum_{batch} \prod_{element} (ub - lb + eps).
    :param lb: <Batch x ...>
    :param ub: <Batch x ...>
    :param by_batch: if True, return a tensor with the areas of the individual abstractions instead
    """
    assert valid_lb_ub(lb, ub)

    diff = ub - lb
    diff += eps  # some dimensions may be degenerate, in which case multiplying by 0 would zero out the product

    while diff.dim() > 1:
        diff = diff.prod(dim=-1)

    if by_batch:
        return diff
    else:
        return diff.sum().item()
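# A small numeric sketch (values assumed): the first box has area 1 * 2 = 2; the second box
# is degenerate in its first dimension, so eps keeps its product tiny but non-zero.
def _demo_total_area() -> float:
    lb = torch.tensor([[0., 0.], [1., 1.]])
    ub = torch.tensor([[1., 2.], [1., 3.]])
    area = total_area(lb, ub)
    assert abs(area - 2.) < 1e-6  # ~2 + O(eps)
    return area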
def sample_points(lb: Tensor, ub: Tensor, K: int) -> Tensor:
    """ Uniformly sample K points for each region, resulting in a large batch of states.
    :param lb: lower bounds, batched
    :param ub: upper bounds, batched
    :param K: how many points to sample per region
    :return: (Batch * K) x State
    """
    assert valid_lb_ub(lb, ub)
    assert K >= 1

    repeat_dims = [1] * (len(lb.size()) - 1)
    base = lb.repeat(K, *repeat_dims)  # repeat K times in the batch dimension, preserving the rest
    width = (ub - lb).repeat(K, *repeat_dims)
    coefs = torch.rand_like(base)
    pts = base + coefs * width
    return pts
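# A shape-only sketch (dimensions assumed): unlike gen_rnd_points(), the K samples are
# flattened into the batch dimension, with the original batch repeated K times.
def _demo_sample_points() -> Tensor:
    pts = sample_points(torch.zeros(4, 3), torch.ones(4, 3), K=5)
    assert pts.shape == (20, 3)  # (Batch * K) x State
    return pts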
def join_all(self, props: List[AbsProp]):
    """ Conjoin multiple properties altogether. Since each property may have a different input space
        and different safety / violation distance functions, this method re-arranges and determines
        the boundaries of sub-regions and which properties each of them should satisfy.
    """
    nprops = len(props)
    assert nprops > 0

    # initialize for the 1st prop
    orig_label = torch.eye(nprops).byte()  # row i marks that an input region should obey property i
    lbs, ubs = props[0].lbub()
    labels = orig_label[[0]].expand(len(lbs), nprops)

    for i, prop in enumerate(props):
        if i == 0:
            continue

        new_lbs, new_ubs = prop.lbub()
        assert valid_lb_ub(new_lbs, new_ubs)
        new_labels = orig_label[[i]].expand(len(new_lbs), nprops)

        lbs, ubs, labels = self._join(lbs, ubs, labels, new_lbs, new_ubs, new_labels)
    return lbs, ubs, labels
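# A sketch of the label bookkeeping (3 properties assumed): row i of the byte identity matrix
# marks "obeys property i"; when regions from different props intersect in _join(), their rows
# are OR-ed so the resulting piece must satisfy both.
def _demo_join_labels() -> Tensor:
    orig_label = torch.eye(3).byte()
    merged = orig_label[[0]] | orig_label[[2]]  # a piece shared by props 0 and 2
    assert torch.equal(merged, torch.tensor([[1, 0, 1]], dtype=torch.uint8))
    return merged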
def split(self, lb: Tensor, ub: Tensor, extra: Optional[Tensor], forward_fn: nn.Module, batch_size: int,
          stop_on_k_all: int = None, stop_on_k_new: int = None, stop_on_k_ops: int = None,
          tiny_width: float = None,
          collapse_res: bool = True) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]]:
    """ Different from verify(), split() does breadth-first traversal. Its objective is to have
        roughly even abstractions with small safety losses for the optimization later.
    :param lb: could be accessed from props.lbub(), but may need additional normalization
    :param ub: same as @param lb
    :param extra: could contain extra info such as the bit vectors for each LB/UB cube showing which
                  safety property it should satisfy in AndProp; or just None
    :param forward_fn: differentiable forward propagation; not passing in net and calling net(input)
                       because different applications may have different net(input, **kwargs)
    :param batch_size: how many boxes to bisect at once, at most; this granularity must be provided.
                       Larger batch_size => faster to compute, but less precise / more averaged
                       (due to more rushing).
    :param stop_on_k_all: if not None, split() stops after the total amount of abstractions exceeds this bar
    :param stop_on_k_new: if not None, split() stops after the amount of abstractions introduced by this
                          split() call exceeds this bar
    :param stop_on_k_ops: if not None, split() stops after this many refinement steps have been applied
    :param tiny_width: if not None, stop refining a dimension if its width is already <= this bar,
                       e.g., setting tiny_width=1e-3 ensures every refined abstraction dimension
                       has width > 5e-4
    :return: <LB, UB> when extra is None, otherwise <LB, UB, extra>
    """
    assert valid_lb_ub(lb, ub)
    assert batch_size > 0

    def _validate_stop_criterion(v, pivot: int):
        assert v is None or (isinstance(v, int) and v > pivot)
        return

    _validate_stop_criterion(stop_on_k_all, 0)
    _validate_stop_criterion(stop_on_k_new, 0)
    _validate_stop_criterion(stop_on_k_ops, -1)  # allow 0 refinement steps, i.e., just evaluate, no refine

    def empty() -> Tensor:
        return empty_like(lb)

    n_orig_abs = len(lb)

    # Not storing viol_lb anymore, as viol_dist is no longer computed. Those violated regions are still refined.
    wl_lb, wl_ub = empty(), empty()
    safe_lb, safe_ub = empty(), empty()
    tiny_lb, tiny_ub = empty(), empty()
    wl_extra = None if extra is None else empty().byte()
    safe_extra = None if extra is None else empty().byte()
    tiny_extra = None if extra is None else empty().byte()
    wl_safe_dist, wl_grad = empty(), empty()

    new_lb, new_ub, new_extra = lb, ub, extra
    iter = 0
    while True:
        iter += 1
        if len(new_lb) > 0:
            with torch.no_grad():
                ''' It's important to have no_grad() here, otherwise the GPU memory will keep growing.
                    With no_grad(), the GPU memory usage is stable. enable_grad() is called inside
                    for grad computation.
                '''
                new_grad, new_safe_dist = self._grad_dists_of(new_lb, new_ub, new_extra, forward_fn, batch_size)
            logging.debug(f'At iter {iter}, another {len(new_lb)} boxes are processed.')

            # process safe abstractions here rather than later
            (new_safe_lb, new_safe_ub, new_safe_extra), (rem_lb, rem_ub, rem_extra, rem_safe_dist, rem_grad) =\
                self._transfer_safe(new_lb, new_ub, new_extra, new_safe_dist, new_grad)
            logging.debug(f'In which {len(new_safe_lb)} confirmed safe.')

            safe_lb = cat0(safe_lb, new_safe_lb)
            safe_ub = cat0(safe_ub, new_safe_ub)
            safe_extra = cat0(safe_extra, new_safe_extra)

            if tiny_width is not None:
                (new_tiny_lb, new_tiny_ub, new_tiny_extra), (rem_lb, rem_ub, rem_extra, rem_safe_dist, rem_grad) =\
                    self._transfer_tiny(rem_lb, rem_ub, rem_extra, rem_safe_dist, rem_grad, tiny_width)
                tiny_lb = cat0(tiny_lb, new_tiny_lb)
                tiny_ub = cat0(tiny_ub, new_tiny_ub)
                tiny_extra = cat0(tiny_extra, new_tiny_extra)
                logging.debug(f'In which {len(new_tiny_lb)} confirmed tiny.')

            wl_lb = cat0(wl_lb, rem_lb)
            wl_ub = cat0(wl_ub, rem_ub)
            wl_extra = cat0(wl_extra, rem_extra)
            wl_safe_dist = cat0(wl_safe_dist, rem_safe_dist)
            wl_grad = cat0(wl_grad, rem_grad)

        logging.debug(f'After iter {iter}, total #{len(safe_lb)} safe, #{len(wl_lb)} in worklist, ' +
                      f'total #{len(tiny_lb)} too small and ignored.')

        if len(wl_lb) == 0:
            # nothing to bisect anymore
            break
        logging.debug(f'At iter {iter}, worklist safe dist min: {wl_safe_dist.min()}, max: {wl_safe_dist.max()}.')

        n_curr_abs = len(safe_lb) + len(tiny_lb) + len(wl_lb)
        if stop_on_k_all is not None and n_curr_abs >= stop_on_k_all:
            # has collected enough abstractions
            break
        if stop_on_k_new is not None and n_curr_abs - n_orig_abs >= stop_on_k_new:
            # has collected enough new abstractions
            break
        if stop_on_k_ops is not None and iter > stop_on_k_ops:
            # has run enough refinement iterations
            break

        ''' Pick the large-loss boxes to bisect first, so as to generate evenly distributed areas.
            There is no need to check whether the entire worklist is selected; topk() should do
            that automatically.
        '''
        tmp = self._pick_top(batch_size, wl_lb, wl_ub, wl_extra, wl_safe_dist, wl_grad, largest=True)
        batch_lb, batch_ub, batch_extra, batch_grad = tmp[:4]
        wl_lb, wl_ub, wl_extra, wl_grad, wl_safe_dist = tmp[4:]

        new_lb, new_ub, new_extra = by_smear(batch_lb, batch_ub, batch_extra, batch_grad)
        pass  # end of worklist while

    logging.debug(f'\nAt the end, split {len(wl_lb)} uncertain (non-zero loss) boxes, ' +
                  f'{len(safe_lb)} safe boxes and {len(tiny_lb)} tiny boxes.')
    if len(wl_lb) > 0:
        logging.debug(f'Non-zero loss boxes have safe loss min {wl_safe_dist.min()} ~ max {wl_safe_dist.max()}.')

    if collapse_res:
        with torch.no_grad():
            all_lb = cat0(wl_lb, safe_lb, tiny_lb)
            all_ub = cat0(wl_ub, safe_ub, tiny_ub)
            all_extra = cat0(wl_extra, safe_extra, tiny_extra)
        if all_extra is None:
            return all_lb, all_ub
        else:
            return all_lb, all_ub, all_extra
    else:
        with torch.no_grad():
            wl_lb = cat0(wl_lb, tiny_lb)
            wl_ub = cat0(wl_ub, tiny_ub)
            wl_extra = cat0(wl_extra, tiny_extra)
        if wl_extra is None:
            return wl_lb, wl_ub
        else:
            return wl_lb, wl_ub, wl_extra
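# A hypothetical driver for split() (the splitter/net/bounds construction is assumed and not
# part of this excerpt): refine the initial boxes into roughly even abstractions for later
# optimization, stopping once ~5000 abstractions have been collected.
def _demo_split(splitter, net: nn.Module, lb: Tensor, ub: Tensor, extra: Optional[Tensor]):
    return splitter.split(lb, ub, extra, forward_fn=net, batch_size=1024,
                          stop_on_k_all=5000, tiny_width=1e-3)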
def verify(self, lb: Tensor, ub: Tensor, extra: Optional[Tensor], forward_fn: nn.Module,
           batch_size: int = 4096) -> Optional[Tensor]:
    """ Verify the safety property or return some found counterexamples.

        The major difference with split() is that verify() does depth-first search, checking
        smaller-loss abstractions first. Otherwise, the memory consumption of BFS-style refinement
        would explode. Also, tiny_width is not considered in verify(); it aims to enumerate areas
        however small, anyway.

    :param lb: Batch x ...
    :param ub: Batch x ...
    :param extra: could contain extra info such as the bit vectors for each LB/UB cube showing which
                  safety property it should satisfy in AndProp; or just None
    :param forward_fn: differentiable forward propagation; not passing in net and calling net(input)
                       because different applications may have different net(input, **kwargs)
    :param batch_size: how many to bisect at one time
    :return: (batched) counterexample tensors, or None if certified
    """
    assert valid_lb_ub(lb, ub)
    assert batch_size > 0

    # track how much has been certified
    tot_area = total_area(lb, ub)
    assert tot_area > 0
    safes_area = 0.
    t0 = timer()

    def empty() -> Tensor:
        return empty_like(lb)

    # no need to save safe_lb/safe_ub
    wl_lb, wl_ub = empty(), empty()
    wl_extra = None if extra is None else empty().byte()
    wl_safe_dist, wl_grad = empty(), empty()

    new_lb, new_ub, new_extra = lb, ub, extra
    iter = 0
    while True:
        iter += 1
        if len(new_lb) > 0:
            ''' It's important to have no_grad() here, otherwise the GPU memory will keep growing.
                With no_grad(), the GPU memory usage is stable. enable_grad() is called inside for
                grad computation.

                viol_dist is now removed: if viol_dist can certify violation, sampling can absolutely
                do the same, but NOT vice versa. So there is no need to compute viol_dist anymore.
                It also shows that using 'safe' as source is slightly better than 'viol', based on
                the first two hard instances in acas-hard.

                I also tried using a 'factor' tensor before, with LB = LB * factor and
                UB = UB * factor, to compute gradients w.r.t. 'factor'. However, that is much worse
                than grads w.r.t. LB and UB directly. One possible reason is that 'factor' can only
                shrink the space in one direction, towards its mid point. This has little to do with
                the actual bisection later on. Grads w.r.t. LB/UB are more directly related.
            '''
            with torch.no_grad():
                new_grad, new_safe_dist = self._grad_dists_of(new_lb, new_ub, new_extra, forward_fn, batch_size)
            logging.debug(f'At iter {iter}, another {len(new_lb)} boxes are processed.')

            # process safe abstractions here rather than later
            (new_safe_lb, new_safe_ub, _), (rem_lb, rem_ub, rem_extra, rem_safe_dist, rem_grad) =\
                self._transfer_safe(new_lb, new_ub, new_extra, new_safe_dist, new_grad)
            logging.debug(f'In which {len(new_safe_lb)} confirmed safe.')

            new_safes_area = total_area(new_safe_lb, new_safe_ub)
            safes_area += new_safes_area

            if len(rem_lb) > 0:
                # sample-check the rest and add to worklist
                cex = self._sample_check(rem_lb, rem_ub, rem_extra, forward_fn)
                if cex is not None:
                    # found cex!
                    logging.debug(f'CEX found by sampling: {cex}')
                    return cex

            wl_lb = cat0(wl_lb, rem_lb)
            wl_ub = cat0(wl_ub, rem_ub)
            wl_extra = cat0(wl_extra, rem_extra)
            wl_safe_dist = cat0(wl_safe_dist, rem_safe_dist)
            wl_grad = cat0(wl_grad, rem_grad)

        safe_area_percent = safes_area / tot_area * 100
        wl_area_percent = 100. - safe_area_percent
        logging.debug(f'After iter {iter}, {pp_time(timer() - t0)}, total ({safe_area_percent:.2f}%) safe, ' +
                      f'total #{len(wl_lb)} ({wl_area_percent:.2f}%) in worklist.')
        # logging.debug(pp_cuda_mem())

        if len(wl_lb) == 0:
            # nothing to bisect anymore
            break
        logging.debug(f'In worklist, safe dist min: {wl_safe_dist.min()}, max: {wl_safe_dist.max()}.')

        ''' Pick the small-loss boxes to bisect first for verification, otherwise BFS style consumes
            huge memory. There is no need to check whether the entire worklist is selected; topk()
            should do that automatically.
        '''
        tmp = self._pick_top(batch_size, wl_lb, wl_ub, wl_extra, wl_safe_dist, wl_grad, largest=False)
        batch_lb, batch_ub, batch_extra, batch_grad = tmp[:4]
        wl_lb, wl_ub, wl_extra, wl_grad, wl_safe_dist = tmp[4:]

        new_lb, new_ub, new_extra = by_smear(batch_lb, batch_ub, batch_extra, batch_grad)
    return None
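# A hypothetical driver for verify() (instance construction assumed): None means the property
# is certified over the whole input region; otherwise the returned tensor holds concrete
# counterexample states found by sampling.
def _demo_verify(verifier, net: nn.Module, lb: Tensor, ub: Tensor, extra: Optional[Tensor]):
    cex = verifier.verify(lb, ub, extra, forward_fn=net, batch_size=4096)
    if cex is None:
        logging.info('Certified safe on the full input region.')
    else:
        logging.info(f'Found {len(cex)} counterexample states.')
    return cex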
def verify(self, lb: Tensor, ub: Tensor, extra: Optional[Tensor], forward_fn: nn.Module,
           batch_size: int = 200) -> Optional[Tensor]:
    """ Verify the safety property or return some found counterexamples.

        The major difference with split() is that verify() does depth-first search, checking
        smaller-loss abstractions first. Otherwise, the memory consumption of BFS-style refinement
        would explode. Also, tiny_width is not considered in verify(); it aims to enumerate areas
        however small, anyway.

    :param lb: Batch x ...
    :param ub: Batch x ...
    :param extra: could contain extra info such as the bit vectors for each LB/UB cube showing which
                  safety property it should satisfy in AndProp; or just None
    :param forward_fn: differentiable forward propagation; not passing in net and calling net(input)
                       because different applications may have different net(input, **kwargs)
    :param batch_size: how many abstractions are checked safe at a time
    :param sample_size: how many points are sampled per abstraction for refinement (only relevant to
                        the commented-out gen_rnd_points() alternative below; not in the signature)
    :return: (batched) counterexample tensors, or None if certified
    """
    assert valid_lb_ub(lb, ub)
    assert batch_size > 0

    # track how much has been certified
    tot_area = total_area(lb, ub)
    assert tot_area > 0
    safes_area = 0.
    t0 = timer()

    def empty() -> Tensor:
        return empty_like(lb)

    # no need to save safe_lb/safe_ub
    wl_lb, wl_ub = empty(), empty()
    wl_extra = None if extra is None else empty().byte()
    wl_safe_dist = empty()

    new_lb, new_ub, new_extra = lb, ub, extra
    iter = 0
    while True:
        iter += 1
        if len(new_lb) > 0:
            ''' It's important to have no_grad() here, otherwise the GPU memory will keep growing.
                With no_grad(), the GPU memory usage is stable. enable_grad() is called inside for
                grad computation.
            '''
            with torch.no_grad():
                new_safe_dist = self._dists_of(new_lb, new_ub, new_extra, forward_fn, batch_size)
            logging.debug(f'At iter {iter}, another {len(new_lb)} boxes are processed.')

            # process safe abstractions here rather than later
            (new_safe_lb, new_safe_ub, _), (rem_lb, rem_ub, rem_extra, rem_safe_dist) =\
                self._transfer_safe(new_lb, new_ub, new_extra, new_safe_dist)
            logging.debug(f'In which {len(new_safe_lb)} confirmed safe.')

            new_safes_area = total_area(new_safe_lb, new_safe_ub)
            safes_area += new_safes_area

            ''' In bisecter.py, sampling to check for cex happened here, right after processing new
                abstractions; here the sampling can be deferred until the sampling for clustering.
            '''
            wl_lb = cat0(wl_lb, rem_lb)
            wl_ub = cat0(wl_ub, rem_ub)
            wl_extra = cat0(wl_extra, rem_extra)
            wl_safe_dist = cat0(wl_safe_dist, rem_safe_dist)

        safe_area_percent = safes_area / tot_area * 100
        wl_area_percent = 100. - safe_area_percent
        logging.debug(f'After iter {iter}, {pp_time(timer() - t0)}, total ({safe_area_percent:.2f}%) safe, ' +
                      f'total #{len(wl_lb)} ({wl_area_percent:.2f}%) in worklist.')
        # logging.debug(pp_cuda_mem())

        if len(wl_lb) == 0:
            # nothing to bisect anymore
            break
        logging.debug(f'In worklist, safe dist min: {wl_safe_dist.min()}, max: {wl_safe_dist.max()}.')

        ''' Pick the small-loss boxes to bisect first for verification, otherwise BFS style consumes
            huge memory. There is no need to check whether the entire worklist is selected; topk()
            should do that automatically.
        '''
        tmp = self._pick_top(batch_size, wl_lb, wl_ub, wl_extra, wl_safe_dist, largest=False)
        batch_lb, batch_ub, batch_extra = tmp[:3]
        wl_lb, wl_ub, wl_extra, wl_safe_dist = tmp[3:]

        # refine these batch_lb/ubs
        ''' One alternative is to generate grid points using torch.linspace() and/or torch.meshgrid().
            But that can only generate one meshgrid for one abstraction at a time, which should be
            slower than the batched version of generating random points for all given abstractions
            at once. Moreover, the random-points way also allows direct control of how many points
            are sampled per abstraction. With meshgrid, the exact number of sampled points grows
            exponentially as the dimension increases.
        '''
        # sampled_pts, sampled_extra = gen_rnd_points(batch_lb, batch_ub, batch_extra, K=sample_size)
        sampled_pts, sampled_extra = gen_vtx_points(batch_lb, batch_ub, batch_extra)  # faster using vertices
        logging.debug(f'From {len(batch_lb)} abstractions, sampled points shape {sampled_pts.shape}.')
        with torch.no_grad():
            sampled_outs = forward_fn(sampled_pts)

        # check cex from sampled points
        old_shape = list(sampled_outs.shape)
        viol_dist = self.prop.viol_dist_conc(sampled_outs.flatten(0, 1), sampled_extra)
        viol_bits = viol_dist <= 0.
        if viol_bits.any():
            # flatten before masking: viol_bits indexes the (Batch * K) x States layout
            cex = sampled_pts.flatten(0, 1)[viol_bits]  # Batch x K x States => (Batch * K) x States
            logging.debug(f'CEX found by sampling: {cex}')
            return cex
        sampled_outs = sampled_outs.view(*old_shape)

        tmp_t0 = timer()
        refined_outs = self.by_clustering(batch_lb, batch_ub, batch_extra, sampled_pts, sampled_outs)
        logging.debug(f'Refinement in total takes {pp_time(timer() - tmp_t0)}')
        new_lb, new_ub = refined_outs[:2]
        new_extra = None if batch_extra is None else refined_outs[2]
    return None
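# gen_vtx_points() is assumed above to enumerate box corners; below is a minimal sketch of such
# a vertex enumeration (a hypothetical reimplementation, ignoring extra; requires `import
# itertools`): each d-dimensional box yields its 2^d corner points, which is the exponential
# growth in dimension noted for meshgrid as well.
def _sketch_vtx_points(lb: Tensor, ub: Tensor) -> Tensor:
    d = lb.size(-1)
    masks = torch.tensor(list(itertools.product([0., 1.], repeat=d)))  # 2^d x d corner selectors
    return lb.unsqueeze(1) + masks.unsqueeze(0) * (ub - lb).unsqueeze(1)  # Batch x 2^d x d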
def gamma(self) -> Tuple[Tensor, Tensor]:
    """ Transform the abstract elements back into Lower Bounds and Upper Bounds. """
    lb = self.lb()
    ub = self.ub()
    assert valid_lb_ub(lb, ub)
    return lb, ub
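# A round-trip sketch (the concrete element construction is assumed and not shown here):
# gamma() of any abstract element must yield consistent bounds.
def _demo_gamma(ele) -> None:
    lb, ub = ele.gamma()
    assert valid_lb_ub(lb, ub)  # the concretized LB/UB are always element-wise ordered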