def is_solution(input_tuple):
    signs, (known_A0, known_B0, LAYER, known_hidden_so_far, K,
            responses) = input_tuple
    new_signs = np.array(
        [-1 if x == '0' else 1 for x in bin((1 << K) + signs)[3:]])

    if CHEATING:
        if signs % 1001 == 0:
            print('tick', signs)
    else:
        if signs % 100001 == 0:
            # This isn't cheating, but makes things prettier
            print('tick', signs)

    guess_A0 = known_A0 * new_signs
    guess_B0 = known_B0 * new_signs
    # We're going to set up a system of equations here.
    # The matrix has one row per equation, each of the form
    # [h_0 h_1 h_2 h_3 h_4 ... h_n 1]
    # where h_i is the i-th coordinate of the hidden vector after multiplying
    # by the guessed matrix, and the trailing 1 multiplies the bias term.

    inputs = matmul(known_hidden_so_far, guess_A0, guess_B0)
    inputs[inputs < 0] = 0

    if responses is None:
        responses = np.ones((inputs.shape[0], 1))
    else:
        # Append a column of ones so the solver also fits the bias term.
        inputs = np.concatenate([inputs, np.ones((inputs.shape[0], 1))],
                                axis=1)

    solution, res, _, _ = scipy.linalg.lstsq(inputs, responses)

    residuals = np.dot(inputs, solution) - responses
    res = np.std(residuals)

    #print("Recovered vector", solution.flatten())

    if res > 1e-2:
        return (res, new_signs, solution), 0

    bias = residuals.mean(axis=0)

    #solution = np.concatenate([solution, [-bias]])[:, np.newaxis]
    mat = (solution / solution[0][0])[:-1, :]
    if np.any(np.isnan(mat)) or np.any(np.isinf(mat)):
        print("Invalid solution")
        return (res, new_signs, solution), 0
    else:
        s = solution / solution[0][0]
        s[np.abs(s) < 1e-14] = 0

        return (res, new_signs, solution), 1
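
# Illustrative sketch (not part of the original code): the two core mechanics
# of is_solution. (1) Each integer in [0, 2**K) encodes one per-neuron sign
# assignment; bin((1 << K) + signs)[3:] left-pads the binary expansion to
# exactly K digits. (2) A sign guess is accepted when the hidden vectors,
# extended with a ones column for the bias, fit one linear equation almost
# exactly. Toy data below is an assumption, chosen only for the demo.
import numpy as np
import scipy.linalg

K = 3
for signs in range(1 << K):
    new_signs = np.array(
        [-1 if x == '0' else 1 for x in bin((1 << K) + signs)[3:]])
    # e.g. signs == 5 -> '101' -> [ 1 -1  1]

# Toy consistency check: rows are hidden vectors, target is all-ones.
hidden = np.abs(np.random.normal(size=(20, K)))
inputs = np.concatenate([hidden, np.ones((20, 1))], axis=1)
responses = np.ones((20, 1))
solution, *_ = scipy.linalg.lstsq(inputs, responses)
residual_std = np.std(inputs @ solution - responses)
# A tiny residual_std (not expected for this random data) would indicate a
# consistent sign assignment.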
def follow_hyperplane(LAYER,
                      start_point,
                      known_T,
                      known_A,
                      known_B,
                      history=[],
                      MAX_POINTS=1e3,
                      only_need_positive=False):
    """
    This is the ugly algorithm that will let us recover sign for expansive networks.
    Assumes we have extracted up to layer K-1 correctly, and layer K up to sign.

    start_point is a critical point of a neuron on layer K+1

    known_T is the transformation that computes up to layer K-1, with
    known_A and known_B being the layer K matrix up to sign.

    We're going to come up with a bunch of different inputs,
    each of which has the same critical point held constant at zero.
    """
    def choose_new_direction_from_minimize(previous_axis):
        """
        Given the current point which is at a critical point of the next
        layer neuron, compute which direction we should travel to continue
        with finding more points on this hyperplane.

        Our goal is going to be to pick a direction that lets us explore
        a new part of the space we haven't seen before.
        """

        print("Choose a new direction to travel in")
        if len(history) == 0:
            which_to_change = 0
            new_perp_dir = perp_dir
            new_start_point = start_point
            initial_signs = get_polytope_at(known_T, known_A, known_B,
                                            start_point)

            # If we're in the 1 region of the polytope then we try to make it smaller
            # otherwise make it bigger
            fn = min if initial_signs[0] == 1 else max
        else:
            neuron_values = np.array([x[1] for x in history])

            neuron_positive_count = np.sum(neuron_values > 1, axis=0)
            neuron_negative_count = np.sum(neuron_values < -1, axis=0)

            mean_plus_neuron_value = neuron_positive_count / (
                neuron_positive_count + neuron_negative_count + 1)
            mean_minus_neuron_value = neuron_negative_count / (
                neuron_positive_count + neuron_negative_count + 1)

            # we want to find values that are consistently 0 or 1
            # So map 0 -> 0 and 1 -> 0 and the middle to higher values
            if only_need_positive:
                neuron_consistency = mean_plus_neuron_value
            else:
                neuron_consistency = mean_plus_neuron_value * mean_minus_neuron_value

            # Print out how much progress we've made.
            # This estimate is probably worse than Windows 95's estimated time remaining.
            # At least it's monotonic. Be thankful for that.
            print("Progress",
                  "%.1f" % int(np.mean(neuron_consistency != 0) * 100) + "%")
            print("Counts on each side of each neuron")
            print(neuron_positive_count)
            print(neuron_negative_count)

            # Choose the smallest value, which is the most consistent
            which_to_change = np.argmin(neuron_consistency)

            print("Try to explore the other side of neuron", which_to_change)

            if which_to_change != previous_axis:
                if previous_axis is not None and neuron_consistency[
                        previous_axis] == neuron_consistency[which_to_change]:
                    # If the previous thing we were working towards has the same value as this one
                    # then don't change our mind and just keep going at that one
                    # (almost always--sometimes we can get stuck, so let us get unstuck)
                    which_to_change = previous_axis
                    new_start_point = start_point
                    new_perp_dir = perp_dir
                else:
                    valid_axes = np.where(
                        neuron_consistency ==
                        neuron_consistency[which_to_change])[0]

                    best = (np.inf, None, None)

                    for _, potential_hidden_vector, potential_point in history[
                            -1:]:
                        for potential_axis in valid_axes:
                            value = potential_hidden_vector[potential_axis]
                            if np.abs(value) < best[0]:
                                best = (np.abs(value), potential_axis,
                                        potential_point)

                    _, which_to_change, new_start_point = best
                    new_perp_dir = perp_dir

            else:
                new_start_point = start_point
                new_perp_dir = perp_dir

            # If we're in the 1 region of the polytope then we try to make it smaller
            # otherwise make it bigger
            fn = min if neuron_positive_count[
                which_to_change] > neuron_negative_count[
                    which_to_change] else max
            arg_fn = np.argmin if neuron_positive_count[
                which_to_change] > neuron_negative_count[
                    which_to_change] else np.argmax
            print("Changing", which_to_change, 'to flip sides because mean is',
                  mean_plus_neuron_value[which_to_change])

        val = matmul(known_T.forward(new_start_point, with_relu=True), known_A,
                     known_B)[which_to_change]

        initial_signs = get_polytope_at(known_T, known_A, known_B,
                                        new_start_point)

        # Now we're going to figure out what direction makes this biggest/smallest.
        # This doesn't take any queries.
        # There's probably an analytical way to do this.
        # But thinking is hard. Just try 1000 random angles.

        choices = []
        for _ in range(1000):
            random_dir = np.random.normal(size=DIM)
            perp_component = np.dot(random_dir, new_perp_dir) / (np.dot(
                new_perp_dir, new_perp_dir)) * new_perp_dir
            parallel_dir = random_dir - perp_component

            # This is the direction we're going to travel in.
            go_direction = parallel_dir / np.sum(parallel_dir**2)**.5

            try:
                a_bit_further, high = binary_search_towards(
                    known_T, known_A, known_B, new_start_point, initial_signs,
                    go_direction)
            except AcceptableFailure:
                continue
            if a_bit_further is None:
                continue

            # choose a direction that makes the Kth value go down by the most
            val = matmul(
                known_T.forward(a_bit_further[np.newaxis, :], with_relu=True),
                known_A, known_B)[0][which_to_change]

            #print('\t', val, high)

            choices.append([val, new_start_point + high * go_direction])

        best_value, multiple_intersection_point = fn(choices,
                                                     key=lambda x: x[0])

        print('Value', best_value)
        return new_start_point, multiple_intersection_point, which_to_change

    ###################################################
    ### Actual code to do the sign recovery starts. ###
    ###################################################

    start_box_step = 0
    points_on_plane = []

    if CHEATING:
        layer = np.abs(
            cheat_get_inner_layers(np.array(start_point))[LAYER + 1])
        print("Layer", layer)
        which_is_zero = np.argmin(layer)

    current_change_axis = 0

    while True:
        print("\n\n")
        print("-----" * 10)

        if CHEATING:
            layer = np.abs(
                cheat_get_inner_layers(np.array(start_point))[LAYER + 1])
            #print('layer',LAYER+1, layer)
            #print('all inner layers')
            #for e in cheat_get_inner_layers(np.array(start_point)):
            #    print(e)
            which_is_zero_2 = np.argmin(np.abs(layer))

            if which_is_zero_2 != which_is_zero:
                print("STARTED WITH", which_is_zero, "NOW IS", which_is_zero_2)
                print(layer)
                raise RuntimeError("The neuron being tracked at zero changed")

        # Keep track of where we've been, so we can go to new places.
        which_polytope = get_polytope_at(known_T, known_A, known_B,
                                         start_point, False)  # [-1 1 -1]
        hidden_vector = get_hidden_at(known_T, known_A, known_B, LAYER,
                                      start_point, False)
        sign_at_init = sign_to_int(which_polytope)  # 0b010 -> 2

        print("Number of collected points", len(points_on_plane))
        if len(points_on_plane) > MAX_POINTS:
            return points_on_plane, False

        neuron_values = np.array([x[1] for x in history])

        neuron_positive_count = np.sum(neuron_values > 1, axis=0)
        neuron_negative_count = np.sum(neuron_values < -1, axis=0)

        if (np.all(neuron_positive_count > 0) and np.all(neuron_negative_count > 0)) or \
           (only_need_positive and np.all(neuron_positive_count > 0)):
            print("Have all the points we need (1)")
            print(query_count)
            print(neuron_positive_count)
            print(neuron_negative_count)

            neuron_values = np.array([
                get_hidden_at(known_T, known_A, known_B, LAYER, x, False)
                for x in points_on_plane
            ])

            neuron_positive_count = np.sum(neuron_values > 1, axis=0)
            neuron_negative_count = np.sum(neuron_values < -1, axis=0)

            print(neuron_positive_count)
            print(neuron_negative_count)

            return points_on_plane, True

        # 1. find a way to move along the hyperplane by computing the normal
        # direction using the ratios function. Then find a parallel direction.

        try:
            #perp_dir = get_ratios([start_point], [range(DIM)], eps=1e-4)[0].flatten()
            perp_dir = get_ratios_lstsq(0, [start_point], [range(DIM)],
                                        KnownT([], []),
                                        eps=1e-5)[0].flatten()

        except AcceptableFailure:
            print(
                "Failed to compute ratio at start point. Something very bad happened."
            )
            return points_on_plane, False

        # Record these points.
        history.append((which_polytope, hidden_vector, np.copy(start_point)))

        # We can't just pick any parallel direction. If we did, then we would
        # not end up covering much of the input space.

        # Instead, we're going to figure out which layer-1 hyperplanes are "visible"
        # from the current point. Then we're going to try and go reach all of them.

        # This is the point at which the first and second layers intersect.
        start_point, multiple_intersection_point, new_change_axis = choose_new_direction_from_minimize(
            current_change_axis)

        if new_change_axis != current_change_axis:
            start_point, multiple_intersection_point, current_change_axis = choose_new_direction_from_minimize(
                None)

        #if CHEATING:
        #    print("INIT MULTIPLE", cheat_get_inner_layers(multiple_intersection_point))

        # Refine the direction we're going to travel in---stay numerically stable.
        towards_multiple_direction = multiple_intersection_point - start_point
        step_distance = np.sum(towards_multiple_direction**2)**.5

        print("Distance we need to step:", step_distance)

        if step_distance > 1 or True:
            mid_point = 1e-4 * towards_multiple_direction / np.sum(
                towards_multiple_direction**2)**.5 + start_point

            mid_points = do_better_sweep(mid_point,
                                         perp_dir / np.sum(perp_dir**2)**.5,
                                         low=-1e-3,
                                         high=1e-3,
                                         known_T=known_T)

            if len(mid_points) > 0:
                mid_point = mid_points[np.argmin(
                    np.sum((mid_point - mid_points)**2, axis=1))]

                towards_multiple_direction = mid_point - start_point
                towards_multiple_direction = towards_multiple_direction / np.sum(
                    towards_multiple_direction**2)**.5

                initial_signs = get_polytope_at(known_T, known_A, known_B,
                                                start_point)
                _, high = binary_search_towards(known_T, known_A, known_B,
                                                start_point, initial_signs,
                                                towards_multiple_direction)

                multiple_intersection_point = towards_multiple_direction * high + start_point

        # Find the angle of the next hyperplane
        # First, take random steps away from the intersection point
        # Then run the search algorithm to find some intersections
        # what we find will either be a layer-1 or layer-2 intersection.

        print("Now try to find the continuation direction")
        success = None
        while success is None:
            if start_box_step < 0:
                start_box_step = 0
                print("VERY BAD FAILURE")
                print("Choose a new random point to start from")
                which_point = np.random.randint(0, len(history))
                start_point = history[which_point][2]
                print("New point is", which_point)
                current_change_axis = np.random.randint(0, sizes[LAYER + 1])
                print("New axis to change", current_change_axis)
                break

            print("\tStart the box step with size", start_box_step)
            try:
                success, camefrom, stepsize = find_plane_angle(
                    known_T, known_A, known_B, multiple_intersection_point,
                    sign_at_init, start_box_step)
            except AcceptableFailure:
                # Go back to the top and try with a new start point
                print("\tOkay we need to try with a new start point")
                start_box_step = -10

            start_box_step -= 2

        if success is None:
            continue

        val = matmul(
            known_T.forward(multiple_intersection_point, with_relu=True),
            known_A, known_B)[new_change_axis]
        print("Value at multiple:", val)
        val = matmul(known_T.forward(success, with_relu=True), known_A,
                     known_B)[new_change_axis]
        print("Value at success:", val)

        if stepsize < 10:
            new_move_direction = success - multiple_intersection_point

            # We don't want to be right next to the multiple intersection point.
            # So let's binary search to find how far away we can go while remaining in this polytope.
            # Then we'll go half as far as we can maximally go.

            initial_signs = get_polytope_at(known_T, known_A, known_B, success)
            print("polytope at initial", sign_to_int(initial_signs))
            low = 0
            high = 1
            while high - low > 1e-2:
                mid = (high + low) / 2
                query_point = multiple_intersection_point + mid * new_move_direction
                next_signs = get_polytope_at(known_T, known_A, known_B,
                                             query_point)
                print(
                    "polytope at", mid, sign_to_int(next_signs), "%x" %
                    (sign_to_int(next_signs) ^ sign_to_int(initial_signs)))
                if initial_signs == next_signs:
                    low = mid
                else:
                    high = mid
            print("GO TO", mid)

            success = multiple_intersection_point + (mid /
                                                     2) * new_move_direction

            val = matmul(known_T.forward(success, with_relu=True), known_A,
                         known_B)[new_change_axis]
            print("Value at moved success:", val)

        print("Adding the points to the set of known good points")

        points_on_plane.append(start_point)

        if camefrom is not None:
            points_on_plane.append(camefrom)
        #print("Old start point", start_point)
        #print("Set to success", success)
        start_point = success
        start_box_step = max(stepsize - 1, 0)

    return points_on_plane, False
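
# Hedged sketch (standalone toy dimensions are an assumption): the direction
# search inside choose_new_direction_from_minimize keeps every candidate step
# on the critical hyperplane by projecting out the component of a random
# direction along the hyperplane normal perp_dir.
import numpy as np

DIM = 10
perp_dir = np.random.normal(size=DIM)    # normal of the hyperplane we follow
random_dir = np.random.normal(size=DIM)  # arbitrary probe direction

# Subtract the projection of random_dir onto perp_dir, then normalize.
perp_component = np.dot(random_dir, perp_dir) / np.dot(perp_dir, perp_dir) * perp_dir
parallel_dir = random_dir - perp_component
go_direction = parallel_dir / np.sum(parallel_dir**2)**.5

# To first order, moving along go_direction keeps the target neuron at zero.
assert abs(np.dot(go_direction, perp_dir)) < 1e-9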
def improve_row_precision(args):
    """
    Improve the precision of an extracted row.
    We think we know where it is, but let's actually figure it out for sure.

    To do this, start by sampling a bunch of points near where we expect the line to be.
    This gives us a picture like this

                      X
                       X
                    
                   X
               X
                 X
                X

    Where some are correct and some are wrong.
    With some robust statistics, try to fit a line that fits through most of the points
    (in high dimension!)

                      X
                     / X
                    / 
                   X
               X  /
                 /
                X

    This solves the equation and improves the point for us.
    """
    (LAYER, known_T, known_A, known_B, row, did_again) = args
    logger.log("Improve the extracted neuron number", row, level=Logger.INFO)

    logger.log(np.sum(np.abs(known_A[:, row])), level=Logger.INFO)
    if np.sum(np.abs(known_A[:, row])) < 1e-8:
        return known_A[:, row], known_B[row]

    def loss(x, r):
        hidden = known_T.forward(x, with_relu=True, np=jnp)
        dotted = matmul(hidden,
                        jnp.array(known_A)[:, r],
                        jnp.array(known_B)[r],
                        np=jnp)

        return jnp.sum(jnp.square(dotted))
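
    # jax.grad gives the gradient of the scalar loss with respect to the
    # input x; jax.jit compiles both the gradient and the loss so the long
    # optimization loop in get_more_points below stays cheap.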

    loss_grad = jax.jit(jax.grad(loss))
    loss = jax.jit(loss)

    extended_T = known_T.extend_by(known_A, known_B)

    def get_more_points(NUM):
        """
        Gather more points. This procedure is really kind of ugly and should probably be fixed.
        We want to find points that are near where we expect them to be.

        So begin by finding preimages to points that are on the line with gradient descent.
        This should be completely possible, because we have d_0 input dimensions but 
        only want to control one inner layer.
        """
        logger.log("Gather some more actual critical points on the plane",
                   level=Logger.INFO)
        stepsize = .1
        critical_points = []
        while len(critical_points) <= NUM:
            logger.log("On this iteration I have ",
                       len(critical_points),
                       "critical points on the plane",
                       level=Logger.INFO)
            points = np.random.normal(0, 1e3, size=(
                100,
                DIM,
            ))

            lr = 10
            for step in range(5000):
                # Use JaX's built in optimizer to do this.
                # We want to adjust the LR so that we get a better solution
                # as we optimize. Probably there is a better way to do this,
                # but this seems to work just fine.

                # No queries involved here.
                if step % 1000 == 0:
                    lr *= .5
                    init, opt_update, get_params = jax.experimental.optimizers.adam(
                        lr)

                    @jax.jit
                    def update(i, opt_state, batch):
                        params = get_params(opt_state)
                        return opt_update(i, loss_grad(batch, row), opt_state)

                    opt_state = init(points)

                if step % 100 == 0:
                    ell = loss(points, row)
                    if CHEATING:
                        # This isn't cheating, but makes things prettier
                        print(ell)
                    if ell < 1e-5:
                        break
                opt_state = update(step, opt_state, points)
                points = opt_state.packed_state[0][0]

            for point in points:
                # For each point, try to see where it actually is.

                # First, if optimization failed, then abort.
                if loss(point, row) > 1e-5:
                    continue

                if LAYER > 0:
                    # If we're on a deeper layer, and if a prior layer is zero, then abort
                    if min(
                            np.min(np.abs(x))
                            for x in known_T.get_hidden_layers(point)) < 1e-4:
                        logger.log("is on prior", level=Logger.INFO)
                        continue

                # print("Stepsize", stepsize)
                tmp = Tracker().query_count
                solution = do_better_sweep(offset=point,
                                           low=-stepsize,
                                           high=stepsize,
                                           known_T=known_T)
                # print("qs", query_count-tmp)
                if len(solution) == 0:
                    stepsize *= 1.1
                elif len(solution) > 1:
                    stepsize /= 2
                elif len(solution) == 1:
                    stepsize *= 0.98
                    potential_solution = solution[0]

                    hiddens = extended_T.get_hidden_layers(potential_solution)

                    this_hidden_vec = extended_T.forward(potential_solution)
                    this_hidden = np.min(np.abs(this_hidden_vec))
                    # Accept only if no earlier-layer neuron is closer to zero
                    # than the target layer's minimum, i.e., the critical
                    # point really belongs to this layer.
                    if min(np.min(np.abs(x))
                           for x in hiddens) > np.abs(this_hidden) * 0.9:
                        critical_points.append(potential_solution)
                    else:
                        logger.log("Reject it", level=Logger.INFO)
        logger.log("Finished with a total of",
                   len(critical_points),
                   "critical points",
                   level=Logger.INFO)
        return critical_points

    critical_points_list = []
    for _ in range(1):
        NUM = sizes[LAYER] * 2
        critical_points_list.extend(get_more_points(NUM))

        critical_points = np.array(critical_points_list)

        hidden_layer = known_T.forward(np.array(critical_points),
                                       with_relu=True)

        if CHEATING:
            out = np.abs(matmul(hidden_layer, A[LAYER], B[LAYER]))
            which_neuron = int(np.median(which_is_zero(0, [out])))
            logger.log("NEURON NUM", which_neuron, level=Logger.INFO)

            crit_val_0 = out[:, which_neuron]

            logger.log(crit_val_0, level=Logger.INFO)

            # print(list(np.sort(np.abs(crit_val_0))))
            logger.log('probability ok',
                       np.mean(np.abs(crit_val_0) < 1e-8),
                       level=Logger.INFO)

        crit_val_1 = matmul(hidden_layer, known_A[:, row], known_B[row])

        best = (None, 1e6)
        upto = 100

        for iteration in range(upto):
            if iteration % 1000 == 0:
                logger.log("ITERATION",
                           iteration,
                           "OF",
                           upto,
                           level=Logger.INFO)
            if iteration % 2 == 0 or True:

                # Try 1000 times to make sure that we get at least one non-zero per axis
                for _ in range(1000):
                    randn = np.random.choice(len(hidden_layer),
                                             NUM + 2,
                                             replace=False)
                    if np.all(np.any(hidden_layer[randn] != 0, axis=0)):
                        break

                hidden = hidden_layer[randn]
                soln, *rest = np.linalg.lstsq(hidden,
                                              np.ones(hidden.shape[0]),
                                              rcond=None)

            else:
                randn = np.random.choice(len(hidden_layer),
                                         min(len(hidden_layer),
                                             hidden_layer.shape[1] + 20),
                                         replace=False)
                soln, _ = trim(hidden_layer[randn],
                               np.ones(hidden_layer.shape[0])[randn],
                               hidden_layer.shape[1])

            crit_val_2 = matmul(hidden_layer, soln, None) - 1

            quality = np.median(np.abs(crit_val_2))

            if iteration % 100 == 0:
                logger.log('quality', quality, best[1], level=Logger.INFO)

            if quality < best[1]:
                best = (soln, quality)

            if quality < 1e-10: break
            if quality < 1e-10 and iteration > 1e4: break
            if quality < 1e-8 and iteration > 1e5: break

        soln, _ = best

        if CHEATING:
            logger.log("Compare",
                       np.median(np.abs(crit_val_0)),
                       level=Logger.INFO)
        logger.log("Compare", np.median(np.abs(crit_val_1)), best[1])

        if np.all(np.abs(soln) > 1e-10):
            break

    logger.log('soln', soln, level=Logger.INFO)

    if np.any(np.abs(soln) < 1e-10):
        logger.log("THIS IS BAD. FIX ME NOW.", level=Logger.ERROR)
        exit(1)

    rescale = np.median(soln / known_A[:, row])
    soln[np.abs(soln) < 1e-10] = known_A[:,
                                         row][np.abs(soln) < 1e-10] * rescale

    if CHEATING:
        other = A[LAYER][:, which_neuron]
        logger.log("real / mine / diff", level=Logger.INFO)
        logger.log(other / other[0], level=Logger.INFO)
        logger.log(soln / soln[0], level=Logger.INFO)
        logger.log(known_A[:, row] / known_A[:, row][0], level=Logger.INFO)
        logger.log(other / other[0] - soln / soln[0], level=Logger.INFO)

    if best[1] < np.mean(np.abs(crit_val_1)) or True:
        return soln, -1
    else:
        logger.log("FAILED TO IMPROVE ACCURACY OF ROW", row, level=Logger.INFO)
        logger.log(np.mean(np.abs(crit_val_2)),
                   'vs',
                   np.mean(np.abs(crit_val_1)),
                   level=Logger.INFO)
        return known_A[:, row], known_B[row]
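
# Hedged sketch (not the original implementation): the core of the robust fit
# in improve_row_precision. Repeatedly solve h . w = 1 on random subsets of
# the hidden vectors and keep the solution whose median absolute residual over
# ALL points is smallest, so a minority of bad critical points cannot drag the
# recovered row off the true hyperplane. Names here are illustrative.
import numpy as np

def robust_row_fit(hidden_layer, subset_size, iters=100):
    best_soln, best_quality = None, np.inf
    for _ in range(iters):
        idx = np.random.choice(len(hidden_layer), subset_size, replace=False)
        soln, *_ = np.linalg.lstsq(hidden_layer[idx],
                                   np.ones(subset_size), rcond=None)
        # Median (not mean) residual is what makes this robust to outliers.
        quality = np.median(np.abs(hidden_layer @ soln - 1))
        if quality < best_quality:
            best_soln, best_quality = soln, quality
    return best_soln, best_quality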
def compute_layer_values(critical_points, known_T, LAYER):
    if LAYER == 0:
        COUNT = neuron_count[LAYER+1] * 3
    else:
        COUNT = neuron_count[LAYER+1] * np.log(sizes[LAYER+1]) * 3


    # type: [(ratios, critical_point)]
    this_layer_critical_points = []

    partial_weights = None
    partial_biases = None

    def check_fn(point):
        if partial_weights is None:
            return True
        hidden = matmul(known_T.forward(point, with_relu=True), partial_weights.T, partial_biases)
        if np.any(np.abs(hidden) < 1e-4):
            return False
        
        return True

    
    print()
    print("Start running critical point search to find neurons on layer", LAYER)
    while True:
        print("At this iteration I have", len(this_layer_critical_points), "critical points")

        def reuse_critical_points():
            for witness in critical_points:
                yield witness
        
        this_layer_critical_points.extend(gather_ratios(reuse_critical_points(), known_T, check_fn,
                                                         LAYER, COUNT))

        print("Query count after that search:", query_count)
        print("And now up to ", len(this_layer_critical_points), "critical points")

        ## filter out duplicates
        filtered_points = []

        # Let's not add points that are identical to ones we've already done.
        for i,(ratio1,point1) in enumerate(this_layer_critical_points):
            for ratio2,point2 in this_layer_critical_points[i+1:]:
                if np.sum((point1 - point2)**2)**.5 < 1e-10:
                    break
            else:
                filtered_points.append((ratio1, point1))
        
        this_layer_critical_points = filtered_points

        print("After filtering duplicates we're down to ", len(this_layer_critical_points), "critical points")
        

        print("Start trying to do the graph solving")
        try:
            critical_groups, extracted_normals = graph_solve([x[0] for x in this_layer_critical_points],
                                                             [x[1] for x in this_layer_critical_points],
                                                             neuron_count[LAYER+1],
                                                             LAYER=LAYER,
                                                             debug=True)
            break
        except GatherMoreData as e:
            print("Graph solving failed because we didn't explore all sides of at least one neuron")
            print("Fall back to the hyperplane following algorithm in order to get more data")
            
            def mine(r):
                while len(r) > 0:
                    print("Yielding a point")
                    yield r[0]
                    r = r[1:]
                print("No more to give!")
    
            prev_T = KnownT(known_T.A[:-1], known_T.B[:-1])
            
            _, more_critical_points = sign_recovery.solve_layer_sign(prev_T, known_T.A[-1], known_T.B[-1], mine(e.data),
                                                                     LAYER-1, already_checked_critical_points=True,
                                                                     only_need_positive=True)

            print("Add more", len(more_critical_points))
            this_layer_critical_points.extend(gather_ratios(more_critical_points, known_T, check_fn,
                                                             LAYER, 1e6))
            print("Done adding")
            
            COUNT = neuron_count[LAYER+1]
        except AcceptableFailure as e:
            print("Graph solving failed; get more points")
            COUNT = neuron_count[LAYER+1]
            if hasattr(e, 'partial_solution'):

                if len(e.partial_solution[0]) > 0:
                    partial_weights, corresponding_examples = e.partial_solution
                    print("Got partial solution with shape", partial_weights.shape)
                    if CHEATING:
                        print("Corresponding to", np.argmin(np.abs(cheat_get_inner_layers([x[0] for x in corresponding_examples])[LAYER]),axis=1))
    
                    partial_biases = []
                    for weight, examples in zip(partial_weights, corresponding_examples):

                        hidden = known_T.forward(examples, with_relu=True)
                        print("hidden", np.array(hidden).shape)
                        bias = -np.median(np.dot(hidden, weight))
                        partial_biases.append(bias)
                    partial_biases = np.array(partial_biases)
                    
                
    print("Number of critical points per cluster", [len(x) for x in critical_groups])
    
    point_per_class = [x[0] for x in critical_groups]

    extracted_normals = np.array(extracted_normals).T

    # Compute the bias because we know wx+b=0
    extracted_bias = [matmul(known_T.forward(point_per_class[i], with_relu=True), extracted_normals[:,i], c=None) for i in range(neuron_count[LAYER+1])]

    # Don't forget to negate it.
    # That's important.
    # No, I definitely didn't forget this line the first time around.
    extracted_bias = -np.array(extracted_bias)

    # For the failed-to-identify neurons, set the bias to zero
    extracted_bias *= np.any(extracted_normals != 0,axis=0)[:,np.newaxis]
        

    if CHEATING:
        # Compute how far off we are from the true matrix
        real_scaled = A[LAYER]/A[LAYER][0]
        extracted_scaled = extracted_normals/extracted_normals[0]
        
        mask = []
        reorder_rows = []
        for i in range(len(extracted_bias)):
            which_idx = np.argmin(np.sum(np.abs(real_scaled - extracted_scaled[:,[i]]),axis=0))
            reorder_rows.append(which_idx)
            mask.append((A[LAYER][0,which_idx]))
    
        print('matrix norm difference', np.sum(np.abs(extracted_normals*mask - A[LAYER][:,reorder_rows])))
    else:
        mask = [1]*len(extracted_bias)
    

    return extracted_normals, extracted_bias, mask
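
# Hedged sketch (toy numbers are an assumption): the bias step at the end of
# compute_layer_values. Every witness x sits on its neuron's critical
# hyperplane, so w . h(x) + b = 0 and therefore b = -w . h(x).
import numpy as np

w = np.array([2.0, -1.0, 0.5])            # one recovered (scaled) weight row
h_at_witness = np.array([1.0, 3.0, 2.0])  # hidden vector at a critical point
b = -np.dot(w, h_at_witness)              # recovered bias for that neuron
print(b)                                  # here: -(2 - 3 + 1) = -0.0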