Example #1
0
    def transform_star(self, star):
        '''transform the star through this layer (in place, returns nothing)

        If the star has no basis matrix yet, it is first given an identity basis
        matrix and a zero bias, sized by the number of LP columns.

        Every basis column is unflattened to self.input_shape, executed, reduced by
        self.zero_output (presumably the layer's output on a zero input, so only the
        linear part remains -- TODO confirm), and flattened again. The bias is
        executed without that subtraction.
        '''

        if star.a_mat is None:
            num_dims = star.lpi.get_num_cols()
            star.a_mat = np.identity(num_dims, dtype=self.dtype)
            star.bias = np.zeros(num_dims, dtype=self.dtype)

        def image_of_column(column):
            'push one basis column through the layer, removing the constant offset'
            as_tensor = nn_unflatten(column, self.input_shape)
            shifted = self.execute(as_tensor) - self.zero_output
            return nn_flatten(shifted)

        num_cols = star.a_mat.shape[1]
        new_columns = [image_of_column(star.a_mat[:, j]) for j in range(num_cols)]

        # the new basis uses the (not yet updated) bias dtype
        star.a_mat = np.array(new_columns, dtype=star.bias.dtype).transpose()

        bias_tensor = nn_unflatten(star.bias, self.input_shape)
        star.bias = nn_flatten(self.execute(bias_tensor))
Example #2
0
def make_init_box(min_image, max_image):
    '''make init box

    Flattens the lower-bound and upper-bound images and pairs them element by element.

    returns a list of (lower, upper) tuples, one per flattened input element
    '''

    lows = nn_flatten(min_image)
    highs = nn_flatten(max_image)

    assert lows.size == highs.size

    return [(lo, hi) for lo, hi in zip(lows, highs)]
Example #3
0
def make_init(nn, image_filename, epsilon, specific_image=None):
    '''returns list of (image_id, image_data, classification_label, init_star_state, spec)

    Loads the images (and copies shifted by -epsilon / +epsilon as the lower / upper
    bounds), runs each image through nn, and keeps only images whose argmax output
    matches the stored label. For each kept image, the spec is a disjunction with one
    row per other output i: coefficient +1 at the label and -1 at i, with rhs 0.

    image_id is specific_image when one was requested, otherwise the index of the
    image in the loaded file.
    '''

    rv = []

    images, labels = load_unscaled_images(image_filename,
                                          specific_image=specific_image)
    min_images, _ = load_unscaled_images(image_filename,
                                         specific_image=specific_image,
                                         epsilon=-epsilon)
    max_images, _ = load_unscaled_images(image_filename,
                                         specific_image=specific_image,
                                         epsilon=epsilon)

    print("making init states")

    for image_id, (image, classification) in enumerate(zip(images, labels)):
        output = nn.execute(image)
        flat_output = nn_flatten(output)

        num_outputs = flat_output.shape[0]
        label = np.argmax(flat_output)

        if label != classification:
            # misclassified images are skipped
            continue

        # unsafe if classification is not maximal (anything else is > classification)
        spec_list = []

        for i in range(num_outputs):
            if i == classification:
                continue

            # one coefficient per network output (previously hard-coded to 10,
            # which broke any network without exactly 10 outputs)
            l = [0] * num_outputs

            l[classification] = 1
            l[i] = -1

            spec_list.append(Specification([l], [0]))

        spec = DisjunctiveSpec(spec_list)

        min_image = min_images[image_id]
        max_image = max_images[image_id]

        init_box = make_init_box(min_image, max_image)
        init_box = np.array(init_box, dtype=np.float32)
        init_state = LpStarState(init_box, spec)

        image_index = specific_image if specific_image is not None else image_id

        rv.append((image_index, image, label, init_state, spec))

    return rv
Example #4
0
    def transform_zono(self, zono):
        '''transform the zono through this layer (in place, returns nothing)

        Each generator column of zono.mat_t is unflattened to self.input_shape,
        executed, reduced by self.zero_output (presumably the layer's output on a zero
        input, so only the linear part remains -- TODO confirm), and flattened again.
        The center is executed without that subtraction.
        '''

        cols = []

        for col in range(zono.mat_t.shape[1]):
            vec = zono.mat_t[:, col]
            vec = nn_unflatten(vec, self.input_shape)

            res = self.execute(vec)
            res = res - self.zero_output
            res = nn_flatten(res)

            cols.append(res)

        # keep the generator matrix in the same dtype as the center
        dtype = zono.center.dtype
        zono.mat_t = np.array(cols, dtype=dtype).transpose()

        start_center = nn_unflatten(zono.center, self.input_shape)
        end_center = self.execute(start_center)
        zono.center = nn_flatten(end_center)
Example #5
0
def make_prerelu_sims(ss, network):
    '''compute the prerelu simulation values at each remaining layer

    this only saves the state for the remaining layers, before relu is executed
    the output of the network is stored at index len(network.layers)

    Starting from ss.prefilter.simulation[1] at layer ss.cur_layer, the value stored
    for each later layer is the flattened input fed to that layer.

    returns a dict, layer_num -> sim_vector, or None if there is no simulation
    '''

    if ss.prefilter.simulation is None:
        return None

    vec = ss.prefilter.simulation[1].copy()
    start_layer = ss.cur_layer

    sims = {start_layer: vec}

    # current layer may be partially processed: finish its relu if it is one
    if isinstance(network.layers[start_layer], ReluLayer):
        vec = np.clip(vec, 0, np.inf)

    for layer_num in range(start_layer + 1, len(network.layers)):
        layer = network.layers[layer_num]
        sims[layer_num] = vec

        in_shape = layer.get_input_shape()
        in_tensor = nn_unflatten(vec, in_shape).astype(ss.star.a_mat.dtype)
        vec = nn_flatten(layer.execute(in_tensor))

    # final network output goes one past the last layer index
    sims[len(network.layers)] = vec

    return sims
Example #6
0
def gen_adv_single_threaded(network, remaining_secs):
    '''gen adversarial without multiprocessing interface

    Runs try_quick_adversarial, then executes the resulting image on the network to
    confirm that its argmax differs from Settings.ADVERSARIAL_ORIG_LABEL.

    returns (flat_input, flat_output) for a confirmed counterexample, otherwise None
    '''

    tic = time.perf_counter()
    _, aimage = try_quick_adversarial(Settings.ADVERSARIAL_QUICK_NUM_ATTEMPTS,
                                      remaining_secs)
    gen_time = time.perf_counter() - tic

    if aimage is None:
        return None

    if Settings.PRINT_OUTPUT:
        print("try_quick_adversarial found unsafe image")

    tic = time.perf_counter()
    output = network.execute(aimage)
    flat_output = np.ravel(output)
    olabel = np.argmax(output)
    confirmed = olabel != Settings.ADVERSARIAL_ORIG_LABEL
    exec_time = time.perf_counter() - tic

    if Settings.PRINT_OUTPUT:
        print(
            f"Original label: {Settings.ADVERSARIAL_ORIG_LABEL}, output argmax: {olabel}"
        )

        gen_ms = f"{round(1000*gen_time, 1)}ms"
        exec_ms = f"{round(1000*exec_time, 1)}ms"
        print(
            f"counterexample was confirmed: {confirmed}. Gen: {gen_ms}, Exec: {exec_ms}"
        )

    return (nn_flatten(aimage), flat_output) if confirmed else None
Example #7
0
    def find_concrete_io(self, star, branch_tuples):
        '''try to find a concrete input and output in star, that explores the passed-in branch_tuples

        Solves the star's LP for a concrete input and executes it on the network. If the
        executed branching is not within branch_tuples, the LP constraints are
        repeatedly tightened (rhs reduced by 1e-16, 1e-15, ..., below 1e16) on a copy
        of the star until a confirmed input is found or the LP becomes infeasible.

        returns (full_input_flat, flat_output), or None if nothing was confirmed
        '''

        assert Settings.FIND_CONCRETE_COUNTEREXAMPLES

        Timers.tic('find_concrete_io')
        rv = None

        # solve lp to get the input/output
        res = star.minimize_vec(None, return_io=True)

        if res is not None:
            cinput, _ = res

            # every network execution uses the star's dtype
            dtype = star.a_mat.dtype

            # try to confirm the counter-example
            full_cinput_flat = star.to_full_input(cinput).astype(dtype)

            exec_output, exec_branch_list = self.shared.network.execute(
                full_cinput_flat, save_branching=True)
            exec_output = nn_flatten(exec_output)

            if branch_list_in_branch_tuples(exec_branch_list, branch_tuples):
                rv = full_cinput_flat, exec_output
            else:
                print(
                    ". weakly-tested code: couldn't confirm countereample... tightening constraints"
                )

                # try to make each of the constraints a little tighter
                star_copy = star.copy()

                rhs_original = star_copy.lpi.get_rhs()
                rhs = rhs_original.copy()

                # tighten by this factor
                tighten_factor = 1e-16

                while tighten_factor < 1e16:

                    # tighten the constraints a little
                    for i, val in enumerate(rhs_original):
                        rhs[i] = val - tighten_factor

                    star_copy.lpi.set_rhs(rhs)

                    res = star_copy.minimize_vec(None, return_io=True)

                    if res is None:
                        # infeasible
                        break

                    cinput, _ = res

                    # cast exactly like the first attempt above (this was previously
                    # missing, so retries fed the network an un-cast input)
                    full_cinput_flat = star.to_full_input(cinput).astype(dtype)
                    full_cinput = nn_unflatten(
                        full_cinput_flat,
                        self.shared.network.get_input_shape())
                    exec_output, exec_branch_list = self.shared.network.execute(
                        full_cinput, save_branching=True)
                    exec_output = nn_flatten(exec_output)

                    if branch_list_in_branch_tuples(exec_branch_list,
                                                    branch_tuples):
                        rv = full_cinput_flat, exec_output
                        break

                    # for next loop, tighten even more
                    tighten_factor *= 10

        Timers.toc('find_concrete_io')

        return rv
Example #8
0
    def consider_overapprox(self):
        '''conditionally run overapprox analysis

        this may set self.priv.ss to None if overapprox is safe, or if a confirmed
        counterexample was found (either from the simulation or from the
        overapproximation rounds)

        returns is_safe (False does not mean unsafe, just that safe cannot be proven or timeout)
        '''

        do_overapprox = False
        is_safe = False
        concrete_io_tuple = None
        ss = self.priv.ss
        network = self.shared.network
        spec = self.shared.spec

        assert ss.remaining_splits() > 0

        if Settings.BRANCH_MODE == Settings.BRANCH_OVERAPPROX:
            do_overapprox = True
        elif Settings.BRANCH_MODE in [
                Settings.BRANCH_EGO, Settings.BRANCH_EGO_LIGHT
        ]:
            do_overapprox = ss.should_try_overapprox

        # todo: experiment moving this after single-zono overapprox
        if do_overapprox and Settings.SPLIT_IF_IDLE and self.exists_idle_worker(
        ):
            do_overapprox = False

        if do_overapprox:
            # todo: experiment global timeout vs per-round timeout
            start = time.perf_counter()

            def check_cancel_func():
                'worker cancel func. can raise OverapproxCanceledException'

                if self.shared.should_exit.value:
                    raise OverapproxCanceledException(
                        f'shared.should_exit was true')

                now = time.perf_counter()

                # global verification timeout
                if now - self.priv.start_time > Settings.TIMEOUT:
                    raise OverapproxCanceledException('timeout exceeded')

                # timeout for this overapproximation attempt only
                if now - start > Settings.OVERAPPROX_LP_TIMEOUT:
                    raise OverapproxCanceledException('lp timeout exceeded')

            timer_name = 'do_overapprox_rounds'
            Timers.tic(timer_name)

            # compute simulation first (and make sure it's safe)
            prerelu_sims = make_prerelu_sims(ss, network)

            if prerelu_sims is None:
                concrete_io_tuple = None
            else:
                sim_out = prerelu_sims[len(network.layers)]

                if spec.is_violation(sim_out):
                    sim_in_flat = ss.prefilter.simulation[0]
                    sim_in = ss.star.to_full_input(sim_in_flat)
                    sim_in = sim_in.astype(ss.star.a_mat.dtype)

                    # run through complete network in to out before counting it
                    sim_out = network.execute(sim_in)
                    sim_out = nn_flatten(sim_out)

                    if spec.is_violation(sim_out):
                        concrete_io_tuple = [sim_in_flat, sim_out]

                        if Settings.PRINT_OUTPUT:
                            print(
                                "\nOverapproximation found was a confirmed counterexample."
                            )
                            print(
                                f"\nUnsafe Base Branch: {self.priv.ss.branch_str()} (Mode: {Settings.BRANCH_MODE})"
                            )

                        self.found_unsafe(concrete_io_tuple)
                        self.add_branch_str('CONCRETE UNSAFE')

            if concrete_io_tuple is None:
                # sim was safe, proceed with overapproximation

                try:
                    gen_limit = max(self.priv.max_approx_gen,
                                    Settings.OVERAPPROX_MIN_GEN_LIMIT)

                    if Settings.OVERAPPROX_GEN_LIMIT_MULTIPLIER is None:
                        gen_limit = np.inf

                    num_branches = len(ss.branch_tuples)
                    if num_branches > Settings.OVERAPPROX_NEAR_ROOT_MAX_SPLITS:
                        otypes = Settings.OVERAPPROX_TYPES
                    else:
                        otypes = Settings.OVERAPPROX_TYPES_NEAR_ROOT

                    res = do_overapprox_rounds(
                        ss,
                        network,
                        spec,
                        prerelu_sims,
                        check_cancel_func,
                        gen_limit,
                        try_seeded_adversarial=self.try_seeded_adversarial,
                        overapprox_types=otypes)

                    if res.concrete_io_tuple is not None:
                        # record it locally too, so the post-processing below finishes
                        # this branch exactly like the simulation-found case above
                        concrete_io_tuple = res.concrete_io_tuple

                        if Settings.PRINT_OUTPUT:
                            print(
                                "\nviolation star found adversarial was a confirmed counterexample."
                            )
                            print(
                                f"\nUnsafe Base Branch: {self.priv.ss.branch_str()} (Mode: {Settings.BRANCH_MODE})"
                            )

                        self.found_unsafe(concrete_io_tuple)
                        self.add_branch_str('CONCRETE UNSAFE')
                    else:

                        is_safe = res.is_safe
                        safe_str = "safe" if is_safe else "unsafe"
                        self.add_branch_str(f"{safe_str} {res}")

                        if not is_safe:
                            ss.should_try_overapprox = False

                        if Settings.OVERAPPROX_GEN_LIMIT_MULTIPLIER is not None:
                            if is_safe:
                                new_max = Settings.OVERAPPROX_GEN_LIMIT_MULTIPLIER * res.get_max_gens(
                                )
                                self.priv.max_approx_gen = max(
                                    self.priv.max_approx_gen, new_max)
                            else:
                                self.priv.max_approx_gen = 0  # reset limit

                except OverapproxCanceledException as e:
                    # fix timers: pop any timers left running by the cancelled rounds
                    while Timers.stack and Timers.stack[-1].name != timer_name:
                        Timers.toc(Timers.stack[-1].name)

                    self.add_branch_str(f'{str(e)}')
                    self.priv.max_approx_gen = 0  # reset limit

                    ss.should_try_overapprox = False

            Timers.toc(timer_name)

            ##### post overapproximation processing

            if Settings.PRINT_BRANCH_TUPLES:
                print(self.priv.branch_tuples_list[-1])

            if is_safe or concrete_io_tuple is not None:
                # done with this branch

                if Settings.RESULT_SAVE_POLYS:
                    self.save_poly(ss)

                self.priv.ss = None
                self.priv.finished_approx_stars += 1

                # local stats that get updated in update_shared_variables
                self.priv.update_stars += 1
                self.priv.update_work_frac += ss.work_frac
                self.priv.update_stars_in_progress -= 1

                if not self.priv.work_list:
                    # urgently update shared variables to try to get more work
                    self.priv.shared_update_urgent = True
                    self.priv.fulfillment_requested_time = time.perf_counter()

        return is_safe
Example #9
0
    def try_seeded_adversarial(self, dims, abstract_ios):
        '''
        generate adversarial image from abstract counterexample seeds

        For each (cinput, _) in abstract_ios, the first `dims` entries of cinput are
        reshaped to the original image's shape and used to seed self.priv.agen (created
        lazily on first use). A candidate is confirmed if the network's argmax differs
        from Settings.ADVERSARIAL_ORIG_LABEL; the first confirmed one is returned.

        returns concrete_io_tuple (flat_image, flat_output) or None
        '''

        Timers.tic('try_seeded_adversarial')

        assert dims == Settings.ADVERSARIAL_ORIG_IMAGE.size

        # initialized before the loop so an empty abstract_ios returns None
        # instead of raising UnboundLocalError at the return below
        concrete_io_tuple = None

        for cinput, _ in abstract_ios:

            seed_image = nn_unflatten(cinput[:dims],
                                      Settings.ADVERSARIAL_ORIG_IMAGE.shape)

            onnx_path = Settings.ADVERSARIAL_ONNX_PATH
            assert onnx_path is not None

            if self.priv.agen is None:
                # initialize
                ep = Settings.ADVERSARIAL_EPSILON
                im = Settings.ADVERSARIAL_ORIG_IMAGE
                l = Settings.ADVERSARIAL_ORIG_LABEL

                Timers.tic("AgenState init")
                self.priv.agen = AgenState(onnx_path, im, l, ep)
                Timers.toc("AgenState init")

            a = self.priv.agen.try_seeded(seed_image)

            if a is None:
                continue

            aimage, ep = a

            if Settings.PRINT_OUTPUT:
                print(
                    f"try_seeded_adversarial found violation image with ep={ep}"
                )

            if aimage is not None:
                flat_image = nn_flatten(aimage)

                output = self.shared.network.execute(flat_image)
                flat_output = np.ravel(output)

                olabel = np.argmax(output)
                confirmed = olabel != Settings.ADVERSARIAL_ORIG_LABEL

                if Settings.PRINT_OUTPUT:
                    print(
                        f"Original label: {Settings.ADVERSARIAL_ORIG_LABEL}, output argmax: {olabel}"
                    )
                    print(f"counterexample was confirmed: {confirmed}")

                if confirmed:
                    concrete_io_tuple = (flat_image, flat_output)
                    break

        Timers.toc('try_seeded_adversarial')

        return concrete_io_tuple
Example #10
0
def worker_func(worker_index, shared):
    '''worker function during verification

    Sets up the worker (floating-point errors raise, per-process onnx sessions and
    timers when multithreaded). Worker 1 may first run a mixed adversarial search.
    The worker then runs its main loop and writes/prints statistics. On any exception
    it sets shared.had_exception and shared.should_exit, clears its remaining work,
    and dumps its branch tuples.

    worker_index: index of this worker (worker 0 writes the branch-tuple/stats file and
        prints the global timing stats; worker 1 may run the adversarial search)
    shared: state shared between workers (mutex, should_exit, had_exception, network, ...)
    '''

    np.seterr(
        all='raise'
    )  # raise exceptions on floating-point errors instead of printing warnings

    if shared.multithreaded:
        reinit_onnx_sessions(shared.network)
        Timers.stack.clear()  # reset inherited Timers
        tag = f" (Process {worker_index})"
    else:
        tag = ""

    timer_name = f'worker_func{tag}'

    Timers.tic(timer_name)

    priv = PrivateState(worker_index)
    priv.start_time = shared.start_time
    w = Worker(shared, priv)

    if worker_index == 1 and Settings.ADVERSARIAL_IN_WORKERS and Settings.ADVERSARIAL_ONNX_PATH:
        # while worker 0 does overapproximation, worker 1 searches for adversarial images
        priv.agen, aimage = try_quick_adversarial(1)

        for i in range(Settings.ADVERSARIAL_WORKERS_MAX_ITER):

            if aimage is not None:
                if Settings.PRINT_OUTPUT:
                    print(
                        f"mixed_adversarial worker {worker_index} found unsafe image after on iteration {i}"
                    )

                flat_image = nn_flatten(aimage)

                output = w.shared.network.execute(flat_image)
                flat_output = np.ravel(output)

                olabel = np.argmax(output)
                confirmed = olabel != Settings.ADVERSARIAL_ORIG_LABEL

                if Settings.PRINT_OUTPUT:
                    print(
                        f"Original label: {Settings.ADVERSARIAL_ORIG_LABEL}, output argmax: {olabel}"
                    )
                    print(f"counterexample was confirmed: {confirmed}")

                if confirmed:
                    concrete_io_tuple = (flat_image, flat_output)
                    w.found_unsafe(concrete_io_tuple)
                    break

            if shared.should_exit.value != 0:
                break

            # try again using a mixed strategy
            random_attacks_only = False
            aimage = priv.agen.try_mixed_adversarial(i, random_attacks_only)

    try:
        w.main_loop()

        if worker_index == 0 and Settings.PRINT_OUTPUT:
            print("\n")

            if Settings.SAVE_BRANCH_TUPLES_FILENAME is not None:
                with open(Settings.SAVE_BRANCH_TUPLES_FILENAME, 'w') as f:
                    for line in w.priv.branch_tuples_list:
                        f.write(f'{line}\n')

                    if not Settings.TIMING_STATS:
                        f.write(
                            f"\nNo timing stats recorded because Settings.TIMING_STATS was False"
                        )
                    else:
                        f.write("\nStats:\n")

                        as_timer_list = Timers.top_level_timer.get_children_recursive(
                            'advance')
                        fs_timer_list = Timers.top_level_timer.get_children_recursive(
                            'finished_star')
                        to_timer_list = Timers.top_level_timer.get_children_recursive(
                            'do_overapprox_rounds')

                        if as_timer_list:
                            as_timer = as_timer_list[0]
                            exact_secs = as_timer.total_secs

                            if fs_timer_list:
                                exact_secs += fs_timer_list[0].total_secs
                        else:
                            exact_secs = 0

                        if to_timer_list:
                            to_timer = to_timer_list[0]
                            o_secs = to_timer.total_secs
                        else:
                            o_secs = 0

                        total_secs = exact_secs + o_secs

                        f.write(f"Total time: {round(total_secs, 3)} ({round(o_secs, 3)} overapprox, " + \
                                f"{round(exact_secs, 3)} exact)\n")

                        t = round(w.priv.total_overapprox_ms / 1000, 3)
                        f.write(
                            f"Sum total time for ONLY safe overapproxations (optimal): {t}\n"
                        )

        Timers.toc(timer_name)

        if shared.multithreaded and not shared.had_exception.value:
            if worker_index != 0 and Settings.PRINT_OUTPUT and Settings.TIMING_STATS:
                time.sleep(
                    0.2
                )  # delay to try to let worker 0 print timing stats first

            ##############################
            # use mutex so printing doesn't get interrupted. Release it in `finally`:
            # the except handler below acquires the same mutex, so leaving it held
            # after an exception here would deadlock.
            shared.mutex.acquire()

            try:
                # computation time is sum of advance_star and finished_star
                if Settings.TIMING_STATS:
                    as_timer_list = Timers.top_level_timer.get_children_recursive(
                        'advance')
                    fs_timer_list = Timers.top_level_timer.get_children_recursive(
                        'finished_star')
                    to_timer_list = Timers.top_level_timer.get_children_recursive(
                        'do_overapprox_rounds')

                    as_secs = as_timer_list[0].total_secs if as_timer_list else 0
                    fs_secs = fs_timer_list[0].total_secs if fs_timer_list else 0
                    to_secs = to_timer_list[0].total_secs if to_timer_list else 0
                    secs = as_secs + fs_secs

                    exact_percent = 100 * secs / Timers.top_level_timer.total_secs
                    over_percent = 100 * to_secs / Timers.top_level_timer.total_secs
                    sum_percent = exact_percent + over_percent

                    if Settings.PRINT_OUTPUT:
                        if w.priv.total_fulfillment_count > 0:
                            t = w.priv.total_fulfillment_time
                        else:
                            t = 0

                        e_stars = w.priv.finished_stars
                        a_stars = w.priv.finished_approx_stars
                        tot_stars = e_stars + a_stars
                        print(f"Worker {worker_index}: {tot_stars} stars ({e_stars} exact, {a_stars} approx); " + \
                              f"Working: {round(sum_percent, 1)}% (Exact: {round(exact_percent, 1)}%, " + \
                              f"Overapprox: {round(over_percent, 1)}%); " + \
                              f"Waiting: {round(1000*t, 3)}ms ")
            finally:
                shared.mutex.release()
            ##############################

            if Settings.PRINT_OUTPUT and Settings.TIMING_STATS and \
               Settings.NUM_PROCESSES > 1 and worker_index == 0:
                time.sleep(0.4)
                print("")
                Timers.print_stats()
                print("")
    except BaseException:  # same as the original bare except: any failure sets should_exit
        if Settings.PRINT_OUTPUT:
            print("\n")
            traceback.print_exc()

        shared.mutex.acquire()

        try:
            shared.had_exception.value = True
            shared.should_exit.value = True
        finally:
            shared.mutex.release()

        print(f"\nWorker {worker_index} had exception")
        w.clear_remaining_work()

        # dump branch tuples
        if Settings.SAVE_BRANCH_TUPLES_FILENAME is not None:
            with open(Settings.SAVE_BRANCH_TUPLES_FILENAME, 'w') as f:
                for line in w.priv.branch_tuples_list:
                    f.write(f'{line}\n')

        # fix timers: pop timers started inside worker_func, then stop its own timer.
        # Skip this if worker_func's timer was already stopped before the exception
        # (otherwise every remaining timer would be popped and the final toc would fail).
        if any(t.name == timer_name for t in Timers.stack):
            while Timers.stack[-1].name != timer_name:
                Timers.toc(Timers.stack[-1].name)

            Timers.toc(timer_name)