def enumerate_network(init, network, spec=None):
    '''enumerate the branches in the network

    init can either be a 2-d list or an lp_star or an lp_star_state

    settings are controlled by assigning directly to the class Settings, for example
    "Settings.TIMEOUT = 10"

    if spec is not None, a verification problem will be considered for the provided
    Specification object (a spec is required unless Settings.BRANCH_MODE is
    Settings.BRANCH_EXACT)

    the output is an instance of Result

    if Settings.ADVERSARIAL_ONNX_PATH is set and Settings.ADVERSARIAL_TRY_QUICK is True, a
    background process running gen_adv is started; if it is still running at the end of
    the analysis (or when an exception escapes), it is terminated and joined
    '''

    assert Settings.TIMEOUT is not None, "use Settings.TIMEOUT = np.inf for no timeout"
    assert Settings.OVERAPPROX_LP_TIMEOUT is not None, \
        "use Settings.OVERAPPROX_LP_TIMEOUT = np.inf for no timeout"

    if Settings.CHECK_SINGLE_THREAD_BLAS:
        check_openblas_threads()

    Timers.reset()
    Timers.tic('enumerate_network')
    start = time.perf_counter()

    if Settings.BRANCH_MODE != Settings.BRANCH_EXACT:
        assert spec is not None, "spec required for overapproximation analysis"

    if not Settings.EAGER_BOUNDS:
        assert Settings.SPLIT_ORDER == Settings.SPLIT_INORDER

    assert not Settings.RESULT_SAVE_TIMERS or Settings.TIMING_STATS, \
        "RESULT_SAVE_TIMERS cannot be used if TIMING_STATS is False"

    if not Settings.TIMING_STATS:
        Timers.disable()

    # adversarial generation process and queue
    # (p / q refer ONLY to the gen_adv process and its result queue)
    concrete_io_tuple = None
    q = None
    p = None
    found_adv = None

    if Settings.ADVERSARIAL_ONNX_PATH is not None and Settings.ADVERSARIAL_TRY_QUICK:
        q = multiprocessing.Queue()
        found_adv = multiprocessing.Value('i', 0)
        p = multiprocessing.Process(target=gen_adv, args=(q, found_adv, network, Settings.TIMEOUT))
        p.start()
        # don't wait for result... run safety check in parallel

    try:
        init_ss = None

        if concrete_io_tuple is None and time.perf_counter() - start < Settings.TIMEOUT:
            init_ss = make_init_ss(init, network, spec, start) # returns None if timeout

        proven_safe = False
        try_quick = Settings.TRY_QUICK_OVERAPPROX or Settings.SINGLE_SET

        if init_ss is not None and try_quick and spec is not None:
            proven_safe, concrete_io_tuple = try_quick_overapprox(
                init_ss, network, spec, start, found_adv)

        if concrete_io_tuple is not None:
            # non-parallel adversarial example was generated
            if Settings.PRINT_OUTPUT:
                print("Proven unsafe before enumerate")

            rv = Result(network, quick=True)
            rv.result_str = 'unsafe'

            rv.cinput = concrete_io_tuple[0]
            rv.coutput = concrete_io_tuple[1]
        elif init_ss is None or time.perf_counter() - start > Settings.TIMEOUT:
            if Settings.PRINT_OUTPUT:
                print(f"Timeout before enumerate, init_ss is None: {init_ss is None}")

            rv = Result(network, quick=True)
            rv.result_str = 'timeout'
        elif proven_safe:
            if Settings.PRINT_OUTPUT:
                print("Proven safe before enumerate")

            rv = Result(network, quick=True)
            rv.result_str = 'safe'
        else:
            # concrete_io_tuple is necessarily None on this branch

            if p is not None and q is not None:
                # blocks until an item is available on q (presumably put by gen_adv)
                concrete_io_tuple = q.get()
                p.join()
                p.terminate()
                q.cancel_join_thread()
                p = None
                q = None

            if concrete_io_tuple is not None:
                if Settings.PRINT_OUTPUT:
                    print("Initial quick adversarial search found unsafe image.")

                rv = Result(network, quick=True)
                rv.result_str = 'unsafe'

                rv.cinput = concrete_io_tuple[0]
                rv.coutput = concrete_io_tuple[1]
            elif Settings.SINGLE_SET:
                if Settings.PRINT_OUTPUT:
                    print("SINGLE_SET analysis inconclusive.")

                rv = Result(network, quick=True)
                rv.result_str = 'none'
            else:
                num_workers = 1 if Settings.NUM_PROCESSES < 1 else Settings.NUM_PROCESSES
                shared = SharedState(network, spec, num_workers, start)
                shared.push_init(init_ss)

                if shared.result.result_str != 'safe': # easy specs can be proven safe in push_init()
                    Timers.tic('run workers')

                    if num_workers == 1:
                        if Settings.PRINT_OUTPUT:
                            print("Running single-threaded")

                        worker_func(0, shared)
                    else:
                        processes = []

                        if Settings.PRINT_OUTPUT:
                            print(f"Running in parallel with {num_workers} processes")

                        # separate name so the gen_adv handle `p` is never overwritten
                        for index in range(num_workers):
                            proc = multiprocessing.Process(target=worker_func, args=(index, shared))
                            proc.start()
                            processes.append(proc)

                        for proc in processes:
                            proc.join()

                    Timers.toc('run workers')

                rv = shared.result
                rv.total_secs = time.perf_counter() - start
                process_result(shared)
    finally:
        # stop the background adversarial search if it is still running (also when an
        # exception escapes above); join() reaps the terminated child so no zombie remains
        if p is not None and q is not None:
            q.cancel_join_thread()
            p.terminate()
            p.join()

    if rv.total_secs is None:
        rv.total_secs = time.perf_counter() - start

    Timers.toc('enumerate_network')

    if Settings.TIMING_STATS and Settings.PRINT_OUTPUT and rv.result_str != 'error':
        Timers.print_stats()

    return rv
def _timer_total_secs(name, default=0):
    '''total_secs of the first timer named `name` found recursively under
    Timers.top_level_timer, or `default` if no such timer exists'''

    timer_list = Timers.top_level_timer.get_children_recursive(name)

    return timer_list[0].total_secs if timer_list else default


def _save_branch_tuples(w, include_stats):
    '''write each entry of w.priv.branch_tuples_list on its own line to
    Settings.SAVE_BRANCH_TUPLES_FILENAME (overwriting it), optionally followed by a
    timing summary computed from Timers'''

    with open(Settings.SAVE_BRANCH_TUPLES_FILENAME, 'w') as f:
        for line in w.priv.branch_tuples_list:
            f.write(f'{line}\n')

        if not include_stats:
            return

        if not Settings.TIMING_STATS:
            f.write("\nNo timing stats recorded because Settings.TIMING_STATS was False")
            return

        f.write("\nStats:\n")

        # exact time includes 'finished_star' only when an 'advance' timer exists
        as_secs = _timer_total_secs('advance', default=None)
        exact_secs = 0 if as_secs is None else as_secs + _timer_total_secs('finished_star')
        o_secs = _timer_total_secs('do_overapprox_rounds')
        total_secs = exact_secs + o_secs

        f.write(f"Total time: {round(total_secs, 3)} ({round(o_secs, 3)} overapprox, " +
                f"{round(exact_secs, 3)} exact)\n")

        t = round(w.priv.total_overapprox_ms / 1000, 3)
        f.write(f"Sum total time for ONLY safe overapproxations (optimal): {t}\n")


def _worker_adversarial_search(shared, priv, w, worker_index):
    '''search for an adversarial example before enumeration; a confirmed one is reported
    through w.found_unsafe()'''

    # while worker 0 does overapproximation, worker 1 looks for adversarial examples
    priv.agen, aimage = try_quick_adversarial(1)

    for i in range(Settings.ADVERSARIAL_WORKERS_MAX_ITER):
        if aimage is not None:
            if Settings.PRINT_OUTPUT:
                print(f"mixed_adversarial worker {worker_index} found unsafe image after on iteration {i}")

            flat_image = nn_flatten(aimage)

            output = w.shared.network.execute(flat_image)
            flat_output = np.ravel(output)

            olabel = np.argmax(output)
            confirmed = olabel != Settings.ADVERSARIAL_ORIG_LABEL

            if Settings.PRINT_OUTPUT:
                print(f"Original label: {Settings.ADVERSARIAL_ORIG_LABEL}, output argmax: {olabel}")
                print(f"counterexample was confirmed: {confirmed}")

            if confirmed:
                concrete_io_tuple = (flat_image, flat_output)
                w.found_unsafe(concrete_io_tuple)
                return

        if shared.should_exit.value != 0:
            return

        # try again using a mixed strategy
        random_attacks_only = False
        aimage = priv.agen.try_mixed_adversarial(i, random_attacks_only)


def _print_worker_summary(shared, w, worker_index):
    '''print this worker's star counts and time breakdown (multi-process mode), holding
    shared.mutex so output from different workers is not interleaved'''

    if worker_index != 0 and Settings.PRINT_OUTPUT and Settings.TIMING_STATS:
        time.sleep(0.2) # delay to try to let worker 0 print timing stats first

    shared.mutex.acquire() # use mutex so printing doesn't get interrupted

    # release in `finally`: the caller's exception handler acquires the same mutex, so a
    # lock left held here by an exception would deadlock
    try:
        # computation time is sum of advance_star and finished_star
        if Settings.TIMING_STATS:
            as_secs = _timer_total_secs('advance')
            fs_secs = _timer_total_secs('finished_star')
            to_secs = _timer_total_secs('do_overapprox_rounds')

            secs = as_secs + fs_secs
            top_secs = Timers.top_level_timer.total_secs

            exact_percent = 100 * secs / top_secs
            over_percent = 100 * to_secs / top_secs
            sum_percent = exact_percent + over_percent

            if Settings.PRINT_OUTPUT:
                if w.priv.total_fulfillment_count > 0:
                    t = w.priv.total_fulfillment_time
                else:
                    t = 0

                e_stars = w.priv.finished_stars
                a_stars = w.priv.finished_approx_stars
                tot_stars = e_stars + a_stars
                print(f"Worker {worker_index}: {tot_stars} stars ({e_stars} exact, {a_stars} approx); " + \
                      f"Working: {round(sum_percent, 1)}% (Exact: {round(exact_percent, 1)}%, " + \
                      f"Overapprox: {round(over_percent, 1)}%); " + \
                      f"Waiting: {round(1000*t, 3)}ms ")
    finally:
        shared.mutex.release()

    if Settings.PRINT_OUTPUT and Settings.TIMING_STATS and \
            Settings.NUM_PROCESSES > 1 and worker_index == 0:
        time.sleep(0.4)
        print("")
        Timers.print_stats()
        print("")


def worker_func(worker_index, shared):
    '''worker function during verification

    worker_index - index of this worker (enumerate_network passes 0 when single-threaded)
    shared - the SharedState object shared by all workers

    worker 1 first runs an adversarial search when Settings.ADVERSARIAL_IN_WORKERS and
    Settings.ADVERSARIAL_ONNX_PATH are set; every worker then runs Worker.main_loop().

    Any exception is caught here: it is reported, shared.had_exception and
    shared.should_exit are set so other workers stop, and remaining work is cleared.
    '''

    np.seterr(all='raise') # raise exceptions on floating-point errors instead of printing warnings

    if shared.multithreaded:
        reinit_onnx_sessions(shared.network)
        Timers.stack.clear() # reset inherited Timers
        tag = f" (Process {worker_index})"
    else:
        tag = ""

    timer_name = f'worker_func{tag}'

    Timers.tic(timer_name)
    timer_running = True # becomes False once timer_name is stopped on the normal path

    priv = PrivateState(worker_index)
    priv.start_time = shared.start_time
    w = Worker(shared, priv)

    try:
        # inside the try so a failing search also signals the other workers to exit
        if worker_index == 1 and Settings.ADVERSARIAL_IN_WORKERS and Settings.ADVERSARIAL_ONNX_PATH:
            _worker_adversarial_search(shared, priv, w, worker_index)

        w.main_loop()

        if worker_index == 0 and Settings.PRINT_OUTPUT:
            print("\n")

            if Settings.SAVE_BRANCH_TUPLES_FILENAME is not None:
                _save_branch_tuples(w, include_stats=True)

        Timers.toc(timer_name)
        timer_running = False

        if shared.multithreaded and not shared.had_exception.value:
            _print_worker_summary(shared, w, worker_index)
    except BaseException: # deliberately as broad as a bare except: any failure must stop the other workers
        if Settings.PRINT_OUTPUT:
            print("\n")
            traceback.print_exc()

        shared.mutex.acquire()

        try:
            shared.had_exception.value = True
            shared.should_exit.value = True
        finally:
            shared.mutex.release()

        print(f"\nWorker {worker_index} had exception")
        w.clear_remaining_work()

        # dump branch tuples
        if Settings.SAVE_BRANCH_TUPLES_FILENAME is not None:
            _save_branch_tuples(w, include_stats=False)

        # fix timers; only if timer_name is still running, otherwise the loop below would
        # pop unrelated (e.g. the caller's) timers and toc a timer that is not running
        if timer_running:
            while Timers.stack and Timers.stack[-1].name != timer_name:
                Timers.toc(Timers.stack[-1].name)

            Timers.toc(timer_name)