def transform_star(self, star):
    '''transform the star in place by pushing its basis matrix and bias through this layer

    if the star has no basis yet (a_mat is None), an identity basis and a zero
    bias sized by star.lpi.get_num_cols() are created first, using self.dtype.

    every basis column is executed through the layer and self.zero_output is
    subtracted from the result (presumably the layer's output for a zero input,
    so only the linear part acts on the basis -- confirm). the bias is executed
    through the layer without that subtraction.

    nothing is returned; star.a_mat and star.bias are replaced.
    '''

    if star.a_mat is None:
        num_dims = star.lpi.get_num_cols()
        star.a_mat = np.identity(num_dims, dtype=self.dtype)
        star.bias = np.zeros(num_dims, dtype=self.dtype)

    def push_column(column):
        'run a single basis column through the layer and remove the zero-input offset'
        layer_out = self.execute(nn_unflatten(column, self.input_shape))
        return nn_flatten(layer_out - self.zero_output)

    basis = star.a_mat
    new_columns = [push_column(basis[:, idx]) for idx in range(basis.shape[1])]

    # the new basis keeps the dtype of the (not yet updated) bias
    star.a_mat = np.array(new_columns, dtype=star.bias.dtype).transpose()

    biased_out = self.execute(nn_unflatten(star.bias, self.input_shape))
    star.bias = nn_flatten(biased_out)
def make_init_box(min_image, max_image):
    '''build the initial input box from a lower-bound and an upper-bound image

    both images are flattened with nn_flatten. they must have the same number of
    elements (checked with an assert).

    returns a list with one (lower, upper) tuple per flattened input element
    '''

    lows = nn_flatten(min_image)
    highs = nn_flatten(max_image)

    assert lows.size == highs.size

    return [(low, high) for low, high in zip(lows, highs)]
def make_init(nn, image_filename, epsilon, specific_image=None):
    '''returns list of (image_id, image_data, classification_label, init_star_state, spec)

    images and labels come from load_unscaled_images. the same file is loaded
    again with epsilon=-epsilon and epsilon=+epsilon, and those images are the
    lower / upper bounds of each image's init box.

    only images the network classifies correctly (argmax of the flattened output
    equals the label) get an entry. misclassified images are skipped.

    image_id in the result is specific_image when it is given, otherwise the
    enumeration index of the image.
    '''

    rv = []

    images, labels = load_unscaled_images(image_filename, specific_image=specific_image)
    min_images, _ = load_unscaled_images(image_filename, specific_image=specific_image, epsilon=-epsilon)
    max_images, _ = load_unscaled_images(image_filename, specific_image=specific_image, epsilon=epsilon)

    print("making init states")

    for image_id, (image, classification) in enumerate(zip(images, labels)):
        output = nn.execute(image)
        flat_output = nn_flatten(output)

        num_outputs = flat_output.shape[0]
        label = np.argmax(flat_output)

        if label != classification:
            # not correctly classified: no init state for this image
            continue

        # unsafe if classification is not maximal (anything else is > classification).
        # one disjunct per other output; each row is sized by the network's actual
        # output count rather than assuming 10 classes
        spec_list = []

        for other in range(num_outputs):
            if other == classification:
                continue

            row = [0] * num_outputs
            row[classification] = 1
            row[other] = -1

            spec_list.append(Specification([row], [0]))

        spec = DisjunctiveSpec(spec_list)

        init_box = make_init_box(min_images[image_id], max_images[image_id])
        init_box = np.array(init_box, dtype=np.float32)
        init_state = LpStarState(init_box, spec)

        image_index = specific_image if specific_image is not None else image_id

        rv.append((image_index, image, label, init_state, spec))

    return rv
def transform_zono(self, zono):
    '''transform the zono in place through this layer

    every column of zono.mat_t is executed through the layer and self.zero_output
    is subtracted (presumably the layer's output for a zero input, so only the
    linear part acts on the generators -- confirm). zono.center is executed
    through the layer without that subtraction.

    the zono is modified in place and nothing is returned, so no copy is taken.
    '''

    new_cols = []

    for col in range(zono.mat_t.shape[1]):
        vec = nn_unflatten(zono.mat_t[:, col], self.input_shape)
        res = self.execute(vec) - self.zero_output
        new_cols.append(nn_flatten(res))

    # the new generator matrix keeps the dtype of the (not yet updated) center
    zono.mat_t = np.array(new_cols, dtype=zono.center.dtype).transpose()

    start_center = nn_unflatten(zono.center, self.input_shape)
    end_center = self.execute(start_center)
    zono.center = nn_flatten(end_center)
def make_prerelu_sims(ss, network):
    '''compute the prerelu simulation values at each remaining layer

    starting at ss.cur_layer, the simulation state stored in
    ss.prefilter.simulation[1] is recorded as the input of each remaining layer
    (i.e. before that layer, including any relu, executes) and is then pushed
    through the layer. the network output is stored under key len(network.layers).

    returns a dict, layer_num -> sim_vector, or None when there is no simulation
    '''

    simulation = ss.prefilter.simulation

    if simulation is None:
        return None

    num_layers = len(network.layers)
    first_layer = ss.cur_layer
    sim_state = simulation[1].copy()

    sims = {first_layer: sim_state}

    # current layer may be partially processed: if it is a relu layer, only the
    # clipping still needs to be applied before moving on
    if isinstance(network.layers[first_layer], ReluLayer):
        sim_state = np.clip(sim_state, 0, np.inf)

    for idx in range(first_layer + 1, num_layers):
        cur_layer = network.layers[idx]
        sims[idx] = sim_state

        in_shape = cur_layer.get_input_shape()
        tensor_in = nn_unflatten(sim_state, in_shape).astype(ss.star.a_mat.dtype)
        sim_state = nn_flatten(cur_layer.execute(tensor_in))

    # save final output
    sims[num_layers] = sim_state

    return sims
def gen_adv_single_threaded(network, remaining_secs):
    '''gen adversarial without multiprocessing interface

    calls try_quick_adversarial. if that yields an image, the network is executed
    on it, and the image counts as a counterexample when the output argmax differs
    from Settings.ADVERSARIAL_ORIG_LABEL.

    returns (flattened image, flattened output) for a confirmed counterexample,
    otherwise None
    '''

    gen_start = time.perf_counter()
    _, aimage = try_quick_adversarial(Settings.ADVERSARIAL_QUICK_NUM_ATTEMPTS, remaining_secs)
    gen_time = time.perf_counter() - gen_start

    if aimage is None:
        return None

    if Settings.PRINT_OUTPUT:
        print("try_quick_adversarial found unsafe image")

    exec_start = time.perf_counter()
    output = network.execute(aimage)
    flat_output = np.ravel(output)
    olabel = np.argmax(output)
    confirmed = olabel != Settings.ADVERSARIAL_ORIG_LABEL
    exec_time = time.perf_counter() - exec_start

    if Settings.PRINT_OUTPUT:
        print(f"Original label: {Settings.ADVERSARIAL_ORIG_LABEL}, output argmax: {olabel}")

        gen_ms = f"{round(1000*gen_time, 1)}ms"
        exec_ms = f"{round(1000*exec_time, 1)}ms"
        print(f"counterexample was confirmed: {confirmed}. Gen: {gen_ms}, Exec: {exec_ms}")

    if not confirmed:
        return None

    return (nn_flatten(aimage), flat_output)
def find_concrete_io(self, star, branch_tuples):
    '''try to find a concrete input and output in star, that explores the passed-in branch_tuples

    an LP minimization (star.minimize_vec(None, return_io=True)) gives a candidate
    input. it is mapped to the full network input, cast to star.a_mat.dtype and
    executed with save_branching=True. the result is accepted when the executed
    branch list is in branch_tuples.

    if not, the star's LP constraints are tightened by a factor that grows 10x per
    attempt (1e-16 up to 1e16) until a confirmed input is found or the LP becomes
    infeasible.

    returns (full_cinput_flat, exec_output) for a confirmed pair, otherwise None
    '''

    assert Settings.FIND_CONCRETE_COUNTEREXAMPLES

    Timers.tic('find_concrete_io')
    rv = None

    # solve lp to get the input/output
    res = star.minimize_vec(None, return_io=True)

    if res is not None:
        cinput, _ = res

        # try to confirm the counter-example
        full_cinput_flat = star.to_full_input(cinput).astype(star.a_mat.dtype)

        exec_output, exec_branch_list = self.shared.network.execute(full_cinput_flat, save_branching=True)
        exec_output = nn_flatten(exec_output)

        if branch_list_in_branch_tuples(exec_branch_list, branch_tuples):
            rv = full_cinput_flat, exec_output
        else:
            print(". weakly-tested code: couldn't confirm countereample... tightening constraints")

            # try to make each of the constraints a little tighter
            star_copy = star.copy()
            rhs_original = star_copy.lpi.get_rhs()
            rhs = rhs_original.copy()

            # tighten by this factor
            tighten_factor = 1e-16

            while tighten_factor < 1e16:
                # tighten the constraints a little
                for i, val in enumerate(rhs_original):
                    rhs[i] = val - tighten_factor

                star_copy.lpi.set_rhs(rhs)

                res = star_copy.minimize_vec(None, return_io=True)

                if res is None:
                    # infeasible
                    break

                cinput, _ = res

                # cast exactly like the first attempt, so the network receives (and
                # the caller gets back) an input of the same dtype on both paths
                full_cinput_flat = star.to_full_input(cinput).astype(star.a_mat.dtype)
                full_cinput = nn_unflatten(full_cinput_flat, self.shared.network.get_input_shape())

                exec_output, exec_branch_list = self.shared.network.execute(full_cinput, save_branching=True)
                exec_output = nn_flatten(exec_output)

                if branch_list_in_branch_tuples(exec_branch_list, branch_tuples):
                    rv = full_cinput_flat, exec_output
                    break

                # for next loop, tighten even more
                tighten_factor *= 10

    Timers.toc('find_concrete_io')

    return rv
def consider_overapprox(self):
    '''conditionally run overapprox analysis

    this may set self.priv.ss to None if overapprox is safe

    overapproximation runs when Settings.BRANCH_MODE is BRANCH_OVERAPPROX, or when
    it is BRANCH_EGO / BRANCH_EGO_LIGHT and ss.should_try_overapprox is set. it is
    skipped when Settings.SPLIT_IF_IDLE is on and another worker is idle.

    first the stored simulation is checked against the spec; a confirmed violation
    is reported through self.found_unsafe. otherwise do_overapprox_rounds is run
    with a cancel function that enforces Settings.TIMEOUT (since
    self.priv.start_time) and Settings.OVERAPPROX_LP_TIMEOUT (since this call).

    returns is_safe (False does not mean unsafe, just that safe cannot be proven or timeout)
    '''
    do_overapprox = False
    is_safe = False
    concrete_io_tuple = None
    ss = self.priv.ss
    network = self.shared.network
    spec = self.shared.spec

    assert ss.remaining_splits() > 0

    if Settings.BRANCH_MODE == Settings.BRANCH_OVERAPPROX:
        do_overapprox = True
    elif Settings.BRANCH_MODE in [
            Settings.BRANCH_EGO, Settings.BRANCH_EGO_LIGHT
    ]:
        do_overapprox = ss.should_try_overapprox

    # todo: experiment moving this after single-zono overapprox
    if do_overapprox and Settings.SPLIT_IF_IDLE and self.exists_idle_worker(
    ):
        do_overapprox = False

    if do_overapprox:
        # todo: experiment global timeout vs per-round timeout
        start = time.perf_counter()

        def check_cancel_func():
            'worker cancel func. can raise OverapproxCanceledException'

            if self.shared.should_exit.value:
                raise OverapproxCanceledException(
                    f'shared.should_exit was true')

            #if Settings.SPLIT_IF_IDLE and self.exists_idle_worker():
            #    print("cancel idle")
            #    raise OverapproxCanceledException('exists idle worker')

            now = time.perf_counter()

            # global verification timeout, measured from the worker's start time
            if now - self.priv.start_time > Settings.TIMEOUT:
                raise OverapproxCanceledException('timeout exceeded')

            # per-call overapprox timeout, measured from the start of this call
            if now - start > Settings.OVERAPPROX_LP_TIMEOUT:
                raise OverapproxCanceledException('lp timeout exceeded')

        timer_name = 'do_overapprox_rounds'
        Timers.tic(timer_name)

        # compute simulation first (and make sure it's safe)
        prerelu_sims = make_prerelu_sims(ss, network)

        if prerelu_sims is None:
            concrete_io_tuple = None
        else:
            sim_out = prerelu_sims[len(network.layers)]

            if spec.is_violation(sim_out):
                sim_in_flat = ss.prefilter.simulation[0]
                sim_in = ss.star.to_full_input(sim_in_flat)
                sim_in = sim_in.astype(ss.star.a_mat.dtype)

                # run through complete network in to out before counting it
                sim_out = network.execute(sim_in)
                sim_out = nn_flatten(sim_out)

                if spec.is_violation(sim_out):
                    # NOTE(review): this records sim_in_flat (the value before
                    # to_full_input) rather than the full input sim_in that was
                    # actually executed; find_concrete_io returns the full input --
                    # confirm this is intended
                    concrete_io_tuple = [sim_in_flat, sim_out]

                    if Settings.PRINT_OUTPUT:
                        print(
                            "\nOverapproximation found was a confirmed counterexample."
                        )
                        print(
                            f"\nUnsafe Base Branch: {self.priv.ss.branch_str()} (Mode: {Settings.BRANCH_MODE})"
                        )

                    self.found_unsafe(concrete_io_tuple)
                    self.add_branch_str('CONCRETE UNSAFE')

        if concrete_io_tuple is None:
            # sim was safe, proceed with overapproximation
            try:
                # generator limit: at least OVERAPPROX_MIN_GEN_LIMIT, unlimited when
                # no multiplier is configured
                gen_limit = max(self.priv.max_approx_gen,
                                Settings.OVERAPPROX_MIN_GEN_LIMIT)

                if Settings.OVERAPPROX_GEN_LIMIT_MULTIPLIER is None:
                    gen_limit = np.inf

                # branches with few splits so far use the near-root overapprox types
                num_branches = len(ss.branch_tuples)

                if num_branches > Settings.OVERAPPROX_NEAR_ROOT_MAX_SPLITS:
                    otypes = Settings.OVERAPPROX_TYPES
                else:
                    otypes = Settings.OVERAPPROX_TYPES_NEAR_ROOT

                res = do_overapprox_rounds(
                    ss,
                    network,
                    spec,
                    prerelu_sims,
                    check_cancel_func,
                    gen_limit,
                    try_seeded_adversarial=self.try_seeded_adversarial,
                    overapprox_types=otypes)

                if res.concrete_io_tuple is not None:
                    # NOTE(review): on this path the local concrete_io_tuple stays None
                    # and is_safe stays False, so the "done with this branch"
                    # bookkeeping below is skipped; presumably found_unsafe ends
                    # the search -- confirm
                    if Settings.PRINT_OUTPUT:
                        print(
                            "\nviolation star found adversarial was a confirmed counterexample."
                        )
                        print(
                            f"\nUnsafe Base Branch: {self.priv.ss.branch_str()} (Mode: {Settings.BRANCH_MODE})"
                        )

                    self.found_unsafe(res.concrete_io_tuple)
                    self.add_branch_str('CONCRETE UNSAFE')
                else:
                    is_safe = res.is_safe
                    safe_str = "safe" if is_safe else "unsafe"
                    self.add_branch_str(f"{safe_str} {res}")

                    # an inconclusive overapprox is not retried on this branch
                    if not is_safe:
                        ss.should_try_overapprox = False

                    if Settings.OVERAPPROX_GEN_LIMIT_MULTIPLIER is not None:
                        if is_safe:
                            # grow the limit based on what this safe proof needed
                            new_max = Settings.OVERAPPROX_GEN_LIMIT_MULTIPLIER * res.get_max_gens(
                            )
                            self.priv.max_approx_gen = max(
                                self.priv.max_approx_gen, new_max)
                        else:
                            self.priv.max_approx_gen = 0  # reset limit

            except OverapproxCanceledException as e:
                # fix timers: close any timers left open when the exception was raised
                while Timers.stack and Timers.stack[-1].name != timer_name:
                    Timers.toc(Timers.stack[-1].name)

                self.add_branch_str(f'{str(e)}')
                self.priv.max_approx_gen = 0  # reset limit
                ss.should_try_overapprox = False

        Timers.toc(timer_name)

        ##### post overapproximation processing

        if Settings.PRINT_BRANCH_TUPLES:
            print(self.priv.branch_tuples_list[-1])

        if is_safe or concrete_io_tuple is not None:
            # done with this branch
            if Settings.RESULT_SAVE_POLYS:
                self.save_poly(ss)

            self.priv.ss = None
            self.priv.finished_approx_stars += 1

            # local stats that get updated in update_shared_variables
            self.priv.update_stars += 1
            self.priv.update_work_frac += ss.work_frac
            self.priv.update_stars_in_progress -= 1

            if not self.priv.work_list:
                # urgently update shared variables to try to get more work
                self.priv.shared_update_urgent = True
                self.priv.fulfillment_requested_time = time.perf_counter()

    return is_safe
def try_seeded_adversarial(self, dims, abstract_ios):
    '''
    generate adversarial image from abstract counterexample seeds

    for each (cinput, _) in abstract_ios, the first `dims` entries of cinput are
    reshaped to Settings.ADVERSARIAL_ORIG_IMAGE.shape and passed to
    self.priv.agen.try_seeded. self.priv.agen is created lazily on first use.

    a generated image counts as a counterexample when the network output argmax
    differs from Settings.ADVERSARIAL_ORIG_LABEL. the first one found is returned.

    returns concrete_io_tuple (flat image, flat output) or None (also when
    abstract_ios is empty)
    '''

    Timers.tic('try_seeded_adversarial')

    assert dims == Settings.ADVERSARIAL_ORIG_IMAGE.size

    # initialized before the loop so an empty abstract_ios yields None instead of
    # an unbound local at the return below
    concrete_io_tuple = None

    for cinput, _ in abstract_ios:
        seed_image = nn_unflatten(cinput[:dims], Settings.ADVERSARIAL_ORIG_IMAGE.shape)

        onnx_path = Settings.ADVERSARIAL_ONNX_PATH
        assert onnx_path is not None

        if self.priv.agen is None:
            # initialize
            ep = Settings.ADVERSARIAL_EPSILON
            im = Settings.ADVERSARIAL_ORIG_IMAGE
            label = Settings.ADVERSARIAL_ORIG_LABEL

            Timers.tic("AgenState init")
            self.priv.agen = AgenState(onnx_path, im, label, ep)
            Timers.toc("AgenState init")

        a = self.priv.agen.try_seeded(seed_image)

        if a is None:
            continue

        aimage, ep = a

        if Settings.PRINT_OUTPUT:
            print(f"try_seeded_adversarial found violation image with ep={ep}")

        if aimage is None:
            continue

        flat_image = nn_flatten(aimage)
        output = self.shared.network.execute(flat_image)
        flat_output = np.ravel(output)

        olabel = np.argmax(output)
        confirmed = olabel != Settings.ADVERSARIAL_ORIG_LABEL

        if Settings.PRINT_OUTPUT:
            print(f"Original label: {Settings.ADVERSARIAL_ORIG_LABEL}, output argmax: {olabel}")
            print(f"counterexample was confirmed: {confirmed}")

        if confirmed:
            concrete_io_tuple = (flat_image, flat_output)
            break

    Timers.toc('try_seeded_adversarial')

    return concrete_io_tuple
def worker_func(worker_index, shared):
    '''worker function during verification

    sets numpy to raise on floating-point errors and, in multithreaded mode,
    re-initializes onnx sessions and clears the inherited Timers stack. it then
    builds this worker's Worker and runs its main loop.

    worker 1 may first run adversarial attacks (Settings.ADVERSARIAL_IN_WORKERS).
    afterwards worker 0 optionally writes branch tuples and stats to
    Settings.SAVE_BRANCH_TUPLES_FILENAME, and in multithreaded mode each worker
    prints per-worker stats while holding shared.mutex.

    any exception marks shared.had_exception / shared.should_exit, clears this
    worker's remaining work, dumps branch tuples and repairs the Timers stack.
    '''

    np.seterr(
        all='raise'
    )  # raise exceptions on floating-point errors instead of printing warnings

    if shared.multithreaded:
        reinit_onnx_sessions(shared.network)
        Timers.stack.clear()  # reset inherited Timers
        tag = f" (Process {worker_index})"
    else:
        tag = ""

    timer_name = f'worker_func{tag}'

    Timers.tic(timer_name)

    priv = PrivateState(worker_index)
    priv.start_time = shared.start_time
    w = Worker(shared, priv)

    if worker_index == 1 and Settings.ADVERSARIAL_IN_WORKERS and Settings.ADVERSARIAL_ONNX_PATH:
        # while worker 0 does overapproximation, worker 1 tries adversarial attacks
        # (a quick attempt first, then up to ADVERSARIAL_WORKERS_MAX_ITER mixed attempts)
        priv.agen, aimage = try_quick_adversarial(1)

        for i in range(Settings.ADVERSARIAL_WORKERS_MAX_ITER):
            if aimage is not None:
                if Settings.PRINT_OUTPUT:
                    print(
                        f"mixed_adversarial worker {worker_index} found unsafe image after on iteration {i}"
                    )

                flat_image = nn_flatten(aimage)

                output = w.shared.network.execute(flat_image)
                flat_output = np.ravel(output)

                # confirmed when the network's argmax differs from the original label
                olabel = np.argmax(output)
                confirmed = olabel != Settings.ADVERSARIAL_ORIG_LABEL

                if Settings.PRINT_OUTPUT:
                    print(
                        f"Original label: {Settings.ADVERSARIAL_ORIG_LABEL}, output argmax: {olabel}"
                    )
                    print(f"counterexample was confirmed: {confirmed}")

                if confirmed:
                    concrete_io_tuple = (flat_image, flat_output)
                    w.found_unsafe(concrete_io_tuple)
                    break

            if shared.should_exit.value != 0:
                break

            #if shared.finished_initial_overapprox.value == 1 and worker_index != 1:
            # worker 1 finishes all attempts, other works help with enumeration
            #    break

            # try again using a mixed strategy
            random_attacks_only = False
            aimage = priv.agen.try_mixed_adversarial(i, random_attacks_only)

    # NOTE(review): the bare except below also catches KeyboardInterrupt/SystemExit
    try:
        w.main_loop()

        if worker_index == 0 and Settings.PRINT_OUTPUT:
            print("\n")

            if Settings.SAVE_BRANCH_TUPLES_FILENAME is not None:
                with open(Settings.SAVE_BRANCH_TUPLES_FILENAME, 'w') as f:
                    for line in w.priv.branch_tuples_list:
                        f.write(f'{line}\n')

                    if not Settings.TIMING_STATS:
                        f.write(
                            f"\nNo timing stats recorded because Settings.TIMING_STATS was False"
                        )
                    else:
                        f.write("\nStats:\n")

                        # exact time = 'advance' + 'finished_star' timers,
                        # overapprox time = 'do_overapprox_rounds' timer
                        as_timer_list = Timers.top_level_timer.get_children_recursive(
                            'advance')
                        fs_timer_list = Timers.top_level_timer.get_children_recursive(
                            'finished_star')
                        to_timer_list = Timers.top_level_timer.get_children_recursive(
                            'do_overapprox_rounds')

                        if as_timer_list:
                            as_timer = as_timer_list[0]
                            exact_secs = as_timer.total_secs

                            if fs_timer_list:
                                exact_secs += fs_timer_list[0].total_secs
                        else:
                            exact_secs = 0

                        if to_timer_list:
                            to_timer = to_timer_list[0]
                            o_secs = to_timer.total_secs
                        else:
                            o_secs = 0

                        total_secs = exact_secs + o_secs
                        f.write(f"Total time: {round(total_secs, 3)} ({round(o_secs, 3)} overapprox, " + \
                                f"{round(exact_secs, 3)} exact)\n")

                        t = round(w.priv.total_overapprox_ms / 1000, 3)
                        f.write(
                            f"Sum total time for ONLY safe overapproxations (optimal): {t}\n"
                        )

        Timers.toc(timer_name)

        if shared.multithreaded and not shared.had_exception.value:
            if worker_index != 0 and Settings.PRINT_OUTPUT and Settings.TIMING_STATS:
                time.sleep(
                    0.2
                )  # delay to try to let worker 0 print timing stats first

            ##############################
            # NOTE(review): the mutex is released with a plain release() call, not in a
            # finally/with. if anything between acquire and release raises, the except
            # handler below calls shared.mutex.acquire() again, which would block if the
            # lock is non-reentrant -- confirm mutex type
            shared.mutex.acquire()  # use mutex so printing doesn't get interrupted

            # computation time is sum of advance_star and finished_star
            if Settings.TIMING_STATS:
                as_timer_list = Timers.top_level_timer.get_children_recursive(
                    'advance')
                fs_timer_list = Timers.top_level_timer.get_children_recursive(
                    'finished_star')
                to_timer_list = Timers.top_level_timer.get_children_recursive(
                    'do_overapprox_rounds')

                as_secs = as_timer_list[0].total_secs if as_timer_list else 0
                fs_secs = fs_timer_list[0].total_secs if fs_timer_list else 0
                to_secs = to_timer_list[0].total_secs if to_timer_list else 0

                secs = as_secs + fs_secs

                # percentages of this process's top-level timer
                exact_percent = 100 * secs / Timers.top_level_timer.total_secs
                over_percent = 100 * to_secs / Timers.top_level_timer.total_secs
                sum_percent = exact_percent + over_percent

                if Settings.PRINT_OUTPUT:
                    if w.priv.total_fulfillment_count > 0:
                        t = w.priv.total_fulfillment_time
                    else:
                        t = 0

                    e_stars = w.priv.finished_stars
                    a_stars = w.priv.finished_approx_stars
                    tot_stars = e_stars + a_stars

                    print(f"Worker {worker_index}: {tot_stars} stars ({e_stars} exact, {a_stars} approx); " + \
                          f"Working: {round(sum_percent, 1)}% (Exact: {round(exact_percent, 1)}%, " + \
                          f"Overapprox: {round(over_percent, 1)}%); " + \
                          f"Waiting: {round(1000*t, 3)}ms ")

            shared.mutex.release()
            ##############################

            if Settings.PRINT_OUTPUT and Settings.TIMING_STATS and \
               Settings.NUM_PROCESSES > 1 and worker_index == 0:
                time.sleep(0.4)
                print("")
                Timers.print_stats()
                print("")

    except:
        if Settings.PRINT_OUTPUT:
            print("\n")
            traceback.print_exc()

        # signal every worker to stop
        shared.mutex.acquire()
        shared.had_exception.value = True
        shared.should_exit.value = True
        shared.mutex.release()

        print(f"\nWorker {worker_index} had exception")
        w.clear_remaining_work()

        # dump branch tuples
        if Settings.SAVE_BRANCH_TUPLES_FILENAME is not None:
            with open(Settings.SAVE_BRANCH_TUPLES_FILENAME, 'w') as f:
                for line in w.priv.branch_tuples_list:
                    f.write(f'{line}\n')

        # fix timers: close any timers left open by the exception, then this one
        while Timers.stack and Timers.stack[-1].name != timer_name:
            Timers.toc(Timers.stack[-1].name)

        Timers.toc(timer_name)