# Example 1
def enumerate_network(init, network, spec=None):
    '''enumerate the branches in the network

    init can either be a 2-d list or an lp_star or an lp_star_state

    settings are controlled by assigning directly to the class Settings, for example "Settings.timeout = 10"
    if spec is not None, a verification problem will be considered for the provided Specification object

    the output is an instance of Result
    '''

    assert Settings.TIMEOUT is not None, "use Settings.TIMEOUT = np.inf for no timeout"
    assert Settings.OVERAPPROX_LP_TIMEOUT is not None, "use Settings.OVERAPPROX_LP_TIMEOUT = np.inf for no timeout"

    if Settings.CHECK_SINGLE_THREAD_BLAS:
        check_openblas_threads()

    Timers.reset()
    Timers.tic('enumerate_network')
    start = time.perf_counter()

    if Settings.BRANCH_MODE != Settings.BRANCH_EXACT:
        assert spec is not None, "spec required for overapproximation analysis"

    if not Settings.EAGER_BOUNDS:
        assert Settings.SPLIT_ORDER == Settings.SPLIT_INORDER

    assert not Settings.RESULT_SAVE_TIMERS or Settings.TIMING_STATS, \
        "RESULT_SAVE_TIMERS cannot be used if TIMING_STATS is False"

    if not Settings.TIMING_STATS:
        Timers.disable()

    # adversarial generation process and queue
    concrete_io_tuple = None
    q = None
    p = None
    found_adv = None

    # bugfix: initialize unconditionally. Previously proven_safe was assigned
    # only inside the quick-analysis branch below, so the "elif proven_safe"
    # check relied on an accidental invariant (init_ss is None whenever that
    # branch is skipped) to avoid a NameError.
    proven_safe = False

    if Settings.ADVERSARIAL_ONNX_PATH is not None and Settings.ADVERSARIAL_TRY_QUICK:
        # spawn a background process that searches for adversarial examples
        # while the main process runs the safety analysis
        q = multiprocessing.Queue()
        found_adv = multiprocessing.Value('i', 0)

        p = multiprocessing.Process(target=gen_adv,
                                    args=(q, found_adv, network,
                                          Settings.TIMEOUT))
        p.start()
        # don't wait for result... run safety check in parallel

    init_ss = None

    if concrete_io_tuple is None and time.perf_counter(
    ) - start < Settings.TIMEOUT:
        init_ss = make_init_ss(init, network, spec,
                               start)  # returns None if timeout

        try_quick = Settings.TRY_QUICK_OVERAPPROX or Settings.SINGLE_SET

        if init_ss is not None and try_quick and spec is not None:
            # quick overapproximation may prove safety or find a counterexample
            # before the full enumeration is launched
            proven_safe, concrete_io_tuple = try_quick_overapprox(
                init_ss, network, spec, start, found_adv)

    if concrete_io_tuple is not None:
        # non-parallel adversarial example was generated
        if Settings.PRINT_OUTPUT:
            print("Proven unsafe before enumerate")

        rv = Result(network, quick=True)
        rv.result_str = 'unsafe'

        rv.cinput = concrete_io_tuple[0]
        rv.coutput = concrete_io_tuple[1]
    elif init_ss is None or time.perf_counter() - start > Settings.TIMEOUT:
        if Settings.PRINT_OUTPUT:
            print(
                f"Timeout before enumerate, init_ss is None: {init_ss is None}"
            )

        rv = Result(network, quick=True)
        rv.result_str = 'timeout'
    elif proven_safe:
        if Settings.PRINT_OUTPUT:
            print("Proven safe before enumerate")

        rv = Result(network, quick=True)
        rv.result_str = 'safe'
    else:
        concrete_io_tuple = None

        if p is not None and q is not None:
            # collect the background adversarial search result (blocking get),
            # then shut the worker process and queue down
            concrete_io_tuple = q.get()
            p.join()
            p.terminate()
            q.cancel_join_thread()
            p = None
            q = None

        if concrete_io_tuple is not None:
            if Settings.PRINT_OUTPUT:
                print("Initial quick adversarial search found unsafe image.")

            rv = Result(network, quick=True)
            rv.result_str = 'unsafe'

            rv.cinput = concrete_io_tuple[0]
            rv.coutput = concrete_io_tuple[1]
        elif Settings.SINGLE_SET:
            if Settings.PRINT_OUTPUT:
                print("SINGLE_SET analysis inconclusive.")

            rv = Result(network, quick=True)
            rv.result_str = 'none'
        else:
            # full enumeration: run worker processes on the shared state
            num_workers = 1 if Settings.NUM_PROCESSES < 1 else Settings.NUM_PROCESSES

            shared = SharedState(network, spec, num_workers, start)
            shared.push_init(init_ss)

            if shared.result.result_str != 'safe':  # easy specs can be proven safe in push_init()
                Timers.tic('run workers')

                if num_workers == 1:
                    if Settings.PRINT_OUTPUT:
                        print("Running single-threaded")

                    worker_func(0, shared)
                else:
                    processes = []

                    if Settings.PRINT_OUTPUT:
                        print(
                            f"Running in parallel with {num_workers} processes"
                        )

                    # consistency fix: spawn num_workers processes (the value
                    # computed and announced above), rather than re-reading
                    # Settings.NUM_PROCESSES
                    for index in range(num_workers):
                        p = multiprocessing.Process(target=worker_func,
                                                    args=(index, shared))
                        p.start()
                        processes.append(p)

                    for p in processes:
                        p.join()

                Timers.toc('run workers')

            rv = shared.result
            rv.total_secs = time.perf_counter() - start
            process_result(shared)

    # if the background adversarial process is still alive (e.g. the analysis
    # finished early), tear it down so the interpreter can exit cleanly
    if p is not None and q is not None:
        q.cancel_join_thread()
        p.terminate()
        p = None
        q = None

    if rv.total_secs is None:
        rv.total_secs = time.perf_counter() - start

    Timers.toc('enumerate_network')

    if Settings.TIMING_STATS and Settings.PRINT_OUTPUT and rv.result_str != 'error':
        Timers.print_stats()

    return rv
# Example 2
def worker_func(worker_index, shared):
    '''worker function during verification

    Runs one worker's share of the enumeration loop. When multithreaded, this
    executes in a forked child process; `shared` is the SharedState that all
    workers coordinate through (work queue, result, mutex, exit flags).
    Worker 1 may additionally run a mixed adversarial-example search before
    joining the main loop. On any exception, the shared exit/exception flags
    are set so sibling workers shut down.
    '''

    np.seterr(
        all='raise'
    )  # raise exceptions on floating-point errors instead of printing warnings

    if shared.multithreaded:
        # child process: rebuild per-process ONNX sessions and start timers fresh
        reinit_onnx_sessions(shared.network)
        Timers.stack.clear()  # reset inherited Timers
        tag = f" (Process {worker_index})"
    else:
        tag = ""

    timer_name = f'worker_func{tag}'

    Timers.tic(timer_name)

    priv = PrivateState(worker_index)
    priv.start_time = shared.start_time
    w = Worker(shared, priv)

    if worker_index == 1 and Settings.ADVERSARIAL_IN_WORKERS and Settings.ADVERSARIAL_ONNX_PATH:
        # while worker 0 does overapproximation, worker 1
        # runs adversarial-example attacks before joining the main loop
        priv.agen, aimage = try_quick_adversarial(1)

        for i in range(Settings.ADVERSARIAL_WORKERS_MAX_ITER):

            if aimage is not None:
                if Settings.PRINT_OUTPUT:
                    print(
                        f"mixed_adversarial worker {worker_index} found unsafe image after on iteration {i}"
                    )

                # re-run the candidate image through the network to confirm it
                # really changes the classification before reporting unsafe
                flat_image = nn_flatten(aimage)

                output = w.shared.network.execute(flat_image)
                flat_output = np.ravel(output)

                olabel = np.argmax(output)
                confirmed = olabel != Settings.ADVERSARIAL_ORIG_LABEL

                if Settings.PRINT_OUTPUT:
                    print(
                        f"Original label: {Settings.ADVERSARIAL_ORIG_LABEL}, output argmax: {olabel}"
                    )
                    print(f"counterexample was confirmed: {confirmed}")

                if confirmed:
                    concrete_io_tuple = (flat_image, flat_output)
                    w.found_unsafe(concrete_io_tuple)
                    break

            # another worker may have finished the proof / found a result
            if shared.should_exit.value != 0:
                break

            #if shared.finished_initial_overapprox.value == 1 and worker_index != 1:
            # worker 1 finishes all attempts, other works help with enumeration
            #    break

            # try again using a mixed strategy
            random_attacks_only = False
            aimage = priv.agen.try_mixed_adversarial(i, random_attacks_only)

    try:
        # main enumeration loop; returns when work is exhausted or should_exit is set
        w.main_loop()

        if worker_index == 0 and Settings.PRINT_OUTPUT:
            print("\n")

            # worker 0 optionally dumps the branching decisions plus a timing summary
            if Settings.SAVE_BRANCH_TUPLES_FILENAME is not None:
                with open(Settings.SAVE_BRANCH_TUPLES_FILENAME, 'w') as f:
                    for line in w.priv.branch_tuples_list:
                        f.write(f'{line}\n')

                    if not Settings.TIMING_STATS:
                        f.write(
                            f"\nNo timing stats recorded because Settings.TIMING_STATS was False"
                        )
                    else:
                        f.write("\nStats:\n")

                        # exact time = 'advance' + 'finished_star' timers;
                        # overapprox time = 'do_overapprox_rounds' timer
                        as_timer_list = Timers.top_level_timer.get_children_recursive(
                            'advance')
                        fs_timer_list = Timers.top_level_timer.get_children_recursive(
                            'finished_star')
                        to_timer_list = Timers.top_level_timer.get_children_recursive(
                            'do_overapprox_rounds')

                        if as_timer_list:
                            as_timer = as_timer_list[0]
                            exact_secs = as_timer.total_secs

                            if fs_timer_list:
                                exact_secs += fs_timer_list[0].total_secs
                        else:
                            exact_secs = 0

                        if to_timer_list:
                            to_timer = to_timer_list[0]
                            o_secs = to_timer.total_secs
                        else:
                            o_secs = 0

                        total_secs = exact_secs + o_secs

                        f.write(f"Total time: {round(total_secs, 3)} ({round(o_secs, 3)} overapprox, " + \
                                f"{round(exact_secs, 3)} exact)\n")

                        t = round(w.priv.total_overapprox_ms / 1000, 3)
                        f.write(
                            f"Sum total time for ONLY safe overapproxations (optimal): {t}\n"
                        )

        Timers.toc(timer_name)

        if shared.multithreaded and not shared.had_exception.value:
            if worker_index != 0 and Settings.PRINT_OUTPUT and Settings.TIMING_STATS:
                time.sleep(
                    0.2
                )  # delay to try to let worker 0 print timing stats first

            ##############################
            shared.mutex.acquire()
            # use mutex so printing doesn't get interrupted

            # computation time is sum of advance_star and finished_star
            if Settings.TIMING_STATS:
                as_timer_list = Timers.top_level_timer.get_children_recursive(
                    'advance')
                fs_timer_list = Timers.top_level_timer.get_children_recursive(
                    'finished_star')
                to_timer_list = Timers.top_level_timer.get_children_recursive(
                    'do_overapprox_rounds')

                as_secs = as_timer_list[0].total_secs if as_timer_list else 0
                fs_secs = fs_timer_list[0].total_secs if fs_timer_list else 0
                to_secs = to_timer_list[0].total_secs if to_timer_list else 0
                secs = as_secs + fs_secs

                # fractions of this worker's wall time spent in exact vs
                # overapprox analysis (remainder is waiting/overhead)
                exact_percent = 100 * secs / Timers.top_level_timer.total_secs
                over_percent = 100 * to_secs / Timers.top_level_timer.total_secs
                sum_percent = exact_percent + over_percent

                if Settings.PRINT_OUTPUT:
                    if w.priv.total_fulfillment_count > 0:
                        t = w.priv.total_fulfillment_time
                    else:
                        t = 0

                    e_stars = w.priv.finished_stars
                    a_stars = w.priv.finished_approx_stars
                    tot_stars = e_stars + a_stars
                    print(f"Worker {worker_index}: {tot_stars} stars ({e_stars} exact, {a_stars} approx); " + \
                          f"Working: {round(sum_percent, 1)}% (Exact: {round(exact_percent, 1)}%, " + \
                          f"Overapprox: {round(over_percent, 1)}%); " + \
                          f"Waiting: {round(1000*t, 3)}ms ")
            shared.mutex.release()
            ##############################

            # worker 0 prints the global timer tree last, after a short delay
            # so the per-worker lines above appear first
            if Settings.PRINT_OUTPUT and Settings.TIMING_STATS and \
               Settings.NUM_PROCESSES > 1 and worker_index == 0:
                time.sleep(0.4)
                print("")
                Timers.print_stats()
                print("")
    except:
        # deliberate catch-all: any failure in a worker must flag the shared
        # state so sibling workers exit instead of hanging on the queue
        if Settings.PRINT_OUTPUT:
            print("\n")
            traceback.print_exc()

        shared.mutex.acquire()
        shared.had_exception.value = True
        shared.should_exit.value = True
        shared.mutex.release()

        print(f"\nWorker {worker_index} had exception")
        w.clear_remaining_work()

        # dump branch tuples
        if Settings.SAVE_BRANCH_TUPLES_FILENAME is not None:
            with open(Settings.SAVE_BRANCH_TUPLES_FILENAME, 'w') as f:
                for line in w.priv.branch_tuples_list:
                    f.write(f'{line}\n')

        # fix timers: the exception may have unwound past any number of
        # Timers.tic calls, so pop the stack back down to this worker's frame
        while Timers.stack and Timers.stack[-1].name != timer_name:
            Timers.toc(Timers.stack[-1].name)

        Timers.toc(timer_name)