def loss(x, r):
    """Return the sum of squares of ``matmul(hidden, known_A[:, r], known_B[r])``.

    ``hidden`` is ``known_T.forward(x, with_relu=True, np=jnp)``; everything is
    computed with ``jax.numpy`` so the function can be differentiated/jitted.

    NOTE(review): at module level this reads ``known_T``, ``known_A``,
    ``known_B`` and ``matmul`` as globals; the same helper exists as a closure
    inside ``improve_row_precision`` — confirm this copy is actually used.
    """
    activations = known_T.forward(x, with_relu=True, np=jnp)
    column = jnp.array(known_A)[:, r]
    offset = jnp.array(known_B)[r]
    projected = matmul(activations, column, offset, np=jnp)
    squared = jnp.square(projected)
    return jnp.sum(squared)
def check_fn(point):
    """Return False if ``point`` is within 1e-4 of a partial neuron's zero level.

    When ``partial_weights`` is None every point is accepted (True). Otherwise
    the value ``matmul(known_T.forward(point, with_relu=True),
    partial_weights.T, partial_biases)`` is computed and the point is rejected
    if any entry has absolute value below 1e-4.

    NOTE(review): at module level this reads ``partial_weights``,
    ``partial_biases``, ``known_T`` and ``matmul`` as globals; the same helper
    exists as a closure inside ``compute_layer_values``.
    """
    if partial_weights is not None:
        activations = known_T.forward(point, with_relu=True)
        pre_activation = matmul(activations, partial_weights.T, partial_biases)
        return not np.any(np.abs(pre_activation) < 1e-4)
    return True
def is_solution(input_tuple):
    """Test one candidate sign assignment for the columns of a layer.

    Args:
        input_tuple: ``(signs, (known_A0, known_B0, LAYER, known_hidden_so_far,
            K, responses))``. ``signs`` is an integer whose ``K``-bit binary
            expansion (most significant bit first) selects the sign of each
            column: bit ``'0'`` -> -1, bit ``'1'`` -> +1. ``LAYER`` is unpacked
            but not used here.

    The columns of ``known_A0``/``known_B0`` are multiplied by those signs, the
    hidden activations ``matmul(known_hidden_so_far, guess_A0, guess_B0)`` are
    ReLU'd, and a least-squares system ``inputs @ solution ~= responses`` is
    solved. If ``responses`` is None the target is a column of ones (no bias
    column); otherwise a column of ones is appended to ``inputs`` so the last
    entry of ``solution`` acts as a bias.

    Returns:
        ``((res, new_signs, solution), ok)`` where ``res`` is the standard
        deviation of the residual ``inputs @ solution - responses`` and ``ok``
        is 1 only if ``res <= 1e-2`` and ``solution / solution[0][0]`` is
        entirely finite; otherwise ``ok`` is 0.
    """
    signs, (known_A0, known_B0, LAYER, known_hidden_so_far, K, responses) = input_tuple
    # bin((1 << K) + signs) is '0b1' followed by exactly K bits; [3:] drops '0b1'.
    new_signs = np.array(
        [-1 if x == '0' else 1 for x in bin((1 << K) + signs)[3:]])

    # Progress ticks. The non-CHEATING cadence isn't cheating, it just keeps
    # the output prettier.
    tick_every = 1001 if CHEATING else 100001
    if signs % tick_every == 0:
        print('tick', signs)

    guess_A0 = known_A0 * new_signs
    guess_B0 = known_B0 * new_signs

    # We're going to set up a system of equations here.
    # Each row is of the form
    #   [h_0 h_1 h_2 ... h_n 1]
    # where h_i is the hidden vector after multiplying by the guessed matrix,
    # and 1 is the weight for the bias term (only when responses is given).
    inputs = matmul(known_hidden_so_far, guess_A0, guess_B0)
    inputs[inputs < 0] = 0

    if responses is None:
        responses = np.ones((inputs.shape[0], 1))
    else:
        inputs = np.concatenate([inputs, np.ones((inputs.shape[0], 1))], axis=1)

    solution, _, _, _ = scipy.linalg.lstsq(inputs, responses)

    residual = np.dot(inputs, solution) - responses
    res = np.std(residual)

    if res > 1e-2:
        return (res, new_signs, solution), 0

    # Normalizing by the first coefficient yields inf/nan when it is zero; check
    # the whole normalized vector (the old [:-1] slice skipped the last entry and
    # accepted a 1-row solution with solution[0][0] == 0).
    with np.errstate(divide='ignore', invalid='ignore'):
        normalized = solution / solution[0][0]
    if not np.all(np.isfinite(normalized)):
        print("Invalid solution")
        return (res, new_signs, solution), 0
    return (res, new_signs, solution), 1
def choose_new_direction_from_minimize(previous_axis):
    """
    Given the current point which is at a critical point of the next
    layer neuron, compute which direction we should travel to continue
    with finding more points on this hyperplane.

    Our goal is going to be to pick a direction that lets us explore
    a new part of the space we haven't seen before.

    Args:
        previous_axis: index of the hidden unit targeted on the previous
            step, or None to force a fresh choice.

    Returns:
        (new_start_point, multiple_intersection_point, which_to_change):
        the point the search starts from, the point reached by stepping
        along the best of 1000 sampled directions, and the targeted
        hidden-unit index.

    Raises:
        ValueError: from the builtin ``min``/``max`` if no sampled direction
            produced a usable point (``choices`` is empty).

    NOTE(review): this module-level definition reads names it does not
    define (history, perp_dir, start_point, only_need_positive, known_T,
    known_A, known_B, DIM, ...). They match the locals of follow_hyperplane,
    which contains an identical nested copy; as a top-level function they
    would have to be globals — confirm whether this copy is used.
    """

    print("Choose a new direction to travel in")
    if len(history) == 0:
        # No history yet: target hidden unit 0 from the current start point.
        which_to_change = 0
        new_perp_dir = perp_dir
        new_start_point = start_point
        initial_signs = get_polytope_at(known_T, known_A, known_B, start_point)

        # If we're in the 1 region of the polytope then we try to make it smaller
        # otherwise make it bigger
        fn = min if initial_signs[0] == 1 else max
    else:
        # history entries unpack as (_, hidden_vector, point) below; x[1] is the hidden vector.
        neuron_values = np.array([x[1] for x in history])

        # Per hidden unit: number of visited points where it was > 1 / < -1.
        neuron_positive_count = np.sum(neuron_values > 1, axis=0)
        neuron_negative_count = np.sum(neuron_values < -1, axis=0)

        # The +1 keeps the denominator non-zero for units never seen on either side.
        mean_plus_neuron_value = neuron_positive_count / (
            neuron_positive_count + neuron_negative_count + 1)
        mean_minus_neuron_value = neuron_negative_count / (
            neuron_positive_count + neuron_negative_count + 1)

        # we want to find values that are consistently 0 or 1
        # So map 0 -> 0 and 1 -> 0 and the middle to higher values
        if only_need_positive:
            neuron_consistency = mean_plus_neuron_value
        else:
            neuron_consistency = mean_plus_neuron_value * mean_minus_neuron_value

        # Print out how much progress we've made.
        # This estimate is probably worse than Windows 95's estimated time remaining.
        # At least it's monotonic. Be thankful for that.
        print("Progress",
              "%.1f" % int(np.mean(neuron_consistency != 0) * 100) + "%")
        print("Counts on each side of each neuron")
        print(neuron_positive_count)
        print(neuron_negative_count)

        # Choose the smallest value, which is the most consistent
        which_to_change = np.argmin(neuron_consistency)

        print("Try to explore the other side of neuron", which_to_change)

        if which_to_change != previous_axis:
            if previous_axis is not None and neuron_consistency[
                    previous_axis] == neuron_consistency[which_to_change]:
                # If the previous thing we were working towards has the same value as this one
                # then don't change our mind and just keep going at that one
                # (almost always--sometimes we can get stuck, let us get unstuck)
                which_to_change = previous_axis
                new_start_point = start_point
                new_perp_dir = perp_dir
            else:
                # Among all units tied for the minimum consistency, pick the one
                # whose value at the most recent history point is closest to zero,
                # and restart from that history point.
                valid_axes = np.where(
                    neuron_consistency == neuron_consistency[which_to_change])[0]

                best = (np.inf, None, None)

                for _, potential_hidden_vector, potential_point in history[
                        -1:]:
                    for potential_axis in valid_axes:
                        value = potential_hidden_vector[potential_axis]
                        if np.abs(value) < best[0]:
                            best = (np.abs(value), potential_axis,
                                    potential_point)

                _, which_to_change, new_start_point = best
                new_perp_dir = perp_dir
        else:
            new_start_point = start_point
            new_perp_dir = perp_dir

        # If we're in the 1 region of the polytope then we try to make it smaller
        # otherwise make it bigger
        fn = min if neuron_positive_count[
            which_to_change] > neuron_negative_count[
                which_to_change] else max
        # NOTE(review): arg_fn is assigned but never used.
        arg_fn = np.argmin if neuron_positive_count[
            which_to_change] > neuron_negative_count[
                which_to_change] else np.argmax
        print("Changing", which_to_change,
              'to flip sides because mean is',
              mean_plus_neuron_value[which_to_change])

    # NOTE(review): this val is never read; the loop below reassigns it.
    val = matmul(known_T.forward(new_start_point, with_relu=True),
                 known_A, known_B)[which_to_change]

    initial_signs = get_polytope_at(known_T, known_A, known_B,
                                    new_start_point)

    # Now we're going to figure out what direction makes this biggest/smallest
    # this doesn't take any queries
    # There's probably an analytical way to do this.
    # But thinking is hard. Just try 1000 random angles.
    # There are no queries involved in this process.

    choices = []
    for _ in range(1000):
        random_dir = np.random.normal(size=DIM)
        # Remove the component along new_perp_dir so the step is orthogonal to it.
        perp_component = np.dot(random_dir, new_perp_dir) / (np.dot(
            new_perp_dir, new_perp_dir)) * new_perp_dir
        parallel_dir = random_dir - perp_component

        # This is the direction we're going to travel in.
        go_direction = parallel_dir / np.sum(parallel_dir**2)**.5

        try:
            a_bit_further, high = binary_search_towards(
                known_T, known_A, known_B, new_start_point, initial_signs,
                go_direction)
        except AcceptableFailure:
            continue
        if a_bit_further is None:
            continue

        # choose a direction that makes the Kth value go down by the most
        val = matmul(
            known_T.forward(a_bit_further[np.newaxis, :], with_relu=True),
            known_A, known_B)[0][which_to_change]

        #print('\t', val, high)

        choices.append([val, new_start_point + high * go_direction])

    # fn is min or max (chosen above); select by the recorded value.
    best_value, multiple_intersection_point = fn(choices,
                                                 key=lambda x: x[0])

    print('Value', best_value)
    return new_start_point, multiple_intersection_point, which_to_change
def follow_hyperplane(LAYER, start_point, known_T, known_A, known_B,
                      history=None, MAX_POINTS=1e3, only_need_positive=False):
    """
    This is the ugly algorithm that will let us recover sign for expansive networks.
    Assumes we have extracted up to layer K-1 correctly, and layer K up to sign.

    start_point is a neuron on layer K+1

    known_T is the transformation that computes up to layer K-1, with
    known_A and known_B being the layer K matrix up to sign.

    We're going to come up with a bunch of different inputs,
    each of which has the same critical point held constant at zero.

    Args:
        LAYER: index passed to get_hidden_at / cheat_get_inner_layers.
        start_point: input-space point to start walking from.
        known_T, known_A, known_B: see above.
        history: optional list; one ``(polytope_signs, hidden_vector, point)``
            tuple is appended per visited point. It is mutated in place, so a
            caller may pass its own list and inspect it afterwards. When
            omitted, a fresh list is used for this call only.
        MAX_POINTS: return unsuccessfully once more than this many points
            have been collected.
        only_need_positive: succeed as soon as every entry of the hidden
            vector has exceeded 1 at some visited point (instead of also
            requiring a value below -1).

    Returns:
        ``(points_on_plane, success)`` — the collected points and whether
        the coverage criterion above was met.
    """
    # A shared mutable default ([]) would carry visited points over between
    # calls that don't pass history; give each such call its own list.
    if history is None:
        history = []

    def choose_new_direction_from_minimize(previous_axis):
        """
        Given the current point which is at a critical point of the next
        layer neuron, compute which direction we should travel to continue
        with finding more points on this hyperplane.

        Our goal is going to be to pick a direction that lets us explore
        a new part of the space we haven't seen before.

        Returns (new_start_point, multiple_intersection_point, which_to_change).
        Raises ValueError (from min/max) if no sampled direction was usable.
        """

        print("Choose a new direction to travel in")
        if len(history) == 0:
            which_to_change = 0
            new_perp_dir = perp_dir
            new_start_point = start_point
            initial_signs = get_polytope_at(known_T, known_A, known_B, start_point)

            # If we're in the 1 region of the polytope then we try to make it smaller
            # otherwise make it bigger
            fn = min if initial_signs[0] == 1 else max
        else:
            neuron_values = np.array([x[1] for x in history])

            # Per hidden unit: number of visited points where it was > 1 / < -1.
            neuron_positive_count = np.sum(neuron_values > 1, axis=0)
            neuron_negative_count = np.sum(neuron_values < -1, axis=0)

            mean_plus_neuron_value = neuron_positive_count / (
                neuron_positive_count + neuron_negative_count + 1)
            mean_minus_neuron_value = neuron_negative_count / (
                neuron_positive_count + neuron_negative_count + 1)

            # we want to find values that are consistently 0 or 1
            # So map 0 -> 0 and 1 -> 0 and the middle to higher values
            if only_need_positive:
                neuron_consistency = mean_plus_neuron_value
            else:
                neuron_consistency = mean_plus_neuron_value * mean_minus_neuron_value

            # Print out how much progress we've made.
            # This estimate is probably worse than Windows 95's estimated time remaining.
            # At least it's monotonic. Be thankful for that.
            print("Progress",
                  "%.1f" % int(np.mean(neuron_consistency != 0) * 100) + "%")
            print("Counts on each side of each neuron")
            print(neuron_positive_count)
            print(neuron_negative_count)

            # Choose the smallest value, which is the most consistent
            which_to_change = np.argmin(neuron_consistency)

            print("Try to explore the other side of neuron", which_to_change)

            if which_to_change != previous_axis:
                if previous_axis is not None and neuron_consistency[
                        previous_axis] == neuron_consistency[which_to_change]:
                    # If the previous thing we were working towards has the same value as this one
                    # then don't change our mind and just keep going at that one
                    # (almost always--sometimes we can get stuck, let us get unstuck)
                    which_to_change = previous_axis
                    new_start_point = start_point
                    new_perp_dir = perp_dir
                else:
                    # Among tied units, pick the one closest to zero at the most
                    # recent history point and restart from that point.
                    valid_axes = np.where(
                        neuron_consistency == neuron_consistency[which_to_change])[0]

                    best = (np.inf, None, None)

                    for _, potential_hidden_vector, potential_point in history[
                            -1:]:
                        for potential_axis in valid_axes:
                            value = potential_hidden_vector[potential_axis]
                            if np.abs(value) < best[0]:
                                best = (np.abs(value), potential_axis,
                                        potential_point)

                    _, which_to_change, new_start_point = best
                    new_perp_dir = perp_dir
            else:
                new_start_point = start_point
                new_perp_dir = perp_dir

            # If we're in the 1 region of the polytope then we try to make it smaller
            # otherwise make it bigger
            fn = min if neuron_positive_count[
                which_to_change] > neuron_negative_count[
                    which_to_change] else max
            arg_fn = np.argmin if neuron_positive_count[
                which_to_change] > neuron_negative_count[
                    which_to_change] else np.argmax
            print("Changing", which_to_change,
                  'to flip sides because mean is',
                  mean_plus_neuron_value[which_to_change])

        val = matmul(known_T.forward(new_start_point, with_relu=True),
                     known_A, known_B)[which_to_change]

        initial_signs = get_polytope_at(known_T, known_A, known_B,
                                        new_start_point)

        # Now we're going to figure out what direction makes this biggest/smallest
        # this doesn't take any queries
        # There's probably an analytical way to do this.
        # But thinking is hard. Just try 1000 random angles.
        # There are no queries involved in this process.

        choices = []
        for _ in range(1000):
            random_dir = np.random.normal(size=DIM)
            perp_component = np.dot(random_dir, new_perp_dir) / (np.dot(
                new_perp_dir, new_perp_dir)) * new_perp_dir
            parallel_dir = random_dir - perp_component

            # This is the direction we're going to travel in.
            go_direction = parallel_dir / np.sum(parallel_dir**2)**.5

            try:
                a_bit_further, high = binary_search_towards(
                    known_T, known_A, known_B, new_start_point, initial_signs,
                    go_direction)
            except AcceptableFailure:
                continue
            if a_bit_further is None:
                continue

            # choose a direction that makes the Kth value go down by the most
            val = matmul(
                known_T.forward(a_bit_further[np.newaxis, :], with_relu=True),
                known_A, known_B)[0][which_to_change]

            choices.append([val, new_start_point + high * go_direction])

        best_value, multiple_intersection_point = fn(choices,
                                                     key=lambda x: x[0])

        print('Value', best_value)
        return new_start_point, multiple_intersection_point, which_to_change

    ###################################################
    ### Actual code to do the sign recovery starts. ###
    ###################################################

    start_box_step = 0
    points_on_plane = []

    if CHEATING:
        layer = np.abs(
            cheat_get_inner_layers(np.array(start_point))[LAYER + 1])
        print("Layer", layer)
        which_is_zero = np.argmin(layer)

    current_change_axis = 0

    while True:
        print("\n\n")
        print("-----" * 10)

        if CHEATING:
            # Debug check: the walk must stay on the neuron it started on.
            layer = np.abs(
                cheat_get_inner_layers(np.array(start_point))[LAYER + 1])
            which_is_zero_2 = np.argmin(np.abs(layer))

            if which_is_zero_2 != which_is_zero:
                print("STARTED WITH", which_is_zero, "NOW IS", which_is_zero_2)
                print(layer)
                # A bare `raise` here had no active exception and only produced
                # "No active exception to reraise"; raise the same type, explained.
                raise RuntimeError(
                    "follow_hyperplane drifted from neuron %s to neuron %s" %
                    (which_is_zero, which_is_zero_2))

        # Keep track of where we've been, so we can go to new places.
        which_polytope = get_polytope_at(known_T, known_A, known_B,
                                         start_point, False)  # [-1 1 -1]
        hidden_vector = get_hidden_at(known_T, known_A, known_B, LAYER,
                                      start_point, False)
        sign_at_init = sign_to_int(which_polytope)  # 0b010 -> 2

        print("Number of collected points", len(points_on_plane))
        if len(points_on_plane) > MAX_POINTS:
            return points_on_plane, False

        neuron_values = np.array([x[1] for x in history])

        neuron_positive_count = np.sum(neuron_values > 1, axis=0)
        neuron_negative_count = np.sum(neuron_values < -1, axis=0)

        if (np.all(neuron_positive_count > 0) and np.all(neuron_negative_count > 0)) or \
           (only_need_positive and np.all(neuron_positive_count > 0)):
            print("Have all the points we need (1)")
            print(query_count)
            print(neuron_positive_count)
            print(neuron_negative_count)

            # Recount over the points actually returned, for the log.
            neuron_values = np.array([
                get_hidden_at(known_T, known_A, known_B, LAYER, x, False)
                for x in points_on_plane
            ])

            neuron_positive_count = np.sum(neuron_values > 1, axis=0)
            neuron_negative_count = np.sum(neuron_values < -1, axis=0)

            print(neuron_positive_count)
            print(neuron_negative_count)

            return points_on_plane, True

        # 1. find a way to move along the hyperplane by computing the normal
        # direction using the ratios function. Then find a parallel direction.
        try:
            perp_dir = get_ratios_lstsq(0, [start_point], [range(DIM)],
                                        KnownT([], []), eps=1e-5)[0].flatten()
        except AcceptableFailure:
            print(
                "Failed to compute ratio at start point. Something very bad happened."
            )
            return points_on_plane, False

        # Record these points.
        history.append((which_polytope, hidden_vector, np.copy(start_point)))

        # We can't just pick any parallel direction. If we did, then we would
        # not end up covering much of the input space.

        # Instead, we're going to figure out which layer-1 hyperplanes are "visible"
        # from the current point. Then we're going to try and go reach all of them.

        # This is the point at which the first and second layers intersect.
        start_point, multiple_intersection_point, new_change_axis = choose_new_direction_from_minimize(
            current_change_axis)

        if new_change_axis != current_change_axis:
            start_point, multiple_intersection_point, current_change_axis = choose_new_direction_from_minimize(
                None)

        # Refine the direction we're going to travel in---stay numerically stable.
        towards_multiple_direction = multiple_intersection_point - start_point
        step_distance = np.sum(towards_multiple_direction**2)**.5

        print("Distance we need to step:", step_distance)

        # (Formerly guarded by the always-true `step_distance > 1 or True`.)
        mid_point = 1e-4 * towards_multiple_direction / np.sum(
            towards_multiple_direction**2)**.5 + start_point

        # Unused, but it advances NumPy's global RNG; removing it would change
        # every later random draw.
        random_dir = np.random.normal(size=DIM)

        mid_points = do_better_sweep(mid_point,
                                     perp_dir / np.sum(perp_dir**2)**.5,
                                     low=-1e-3,
                                     high=1e-3,
                                     known_T=known_T)

        if len(mid_points) > 0:
            # Snap to the swept point nearest the tentative mid point and
            # re-derive the travel direction from it.
            mid_point = mid_points[np.argmin(
                np.sum((mid_point - mid_points)**2, axis=1))]

            towards_multiple_direction = mid_point - start_point
            towards_multiple_direction = towards_multiple_direction / np.sum(
                towards_multiple_direction**2)**.5

            initial_signs = get_polytope_at(known_T, known_A, known_B,
                                            start_point)
            _, high = binary_search_towards(known_T, known_A, known_B,
                                            start_point, initial_signs,
                                            towards_multiple_direction)

            multiple_intersection_point = towards_multiple_direction * high + start_point

        # Find the angle of the next hyperplane
        # First, take random steps away from the intersection point
        # Then run the search algorithm to find some intersections
        # what we find will either be a layer-1 or layer-2 intersection.

        print("Now try to find the continuation direction")
        success = None
        while success is None:
            if start_box_step < 0:
                # Repeated failures: restart from a random previously visited point.
                start_box_step = 0
                print("VERY BAD FAILURE")
                print("Choose a new random point to start from")
                which_point = np.random.randint(0, len(history))
                start_point = history[which_point][2]
                print("New point is", which_point)
                current_change_axis = np.random.randint(0, sizes[LAYER + 1])
                print("New axis to change", current_change_axis)
                break

            print("\tStart the box step with size", start_box_step)
            try:
                success, camefrom, stepsize = find_plane_angle(
                    known_T, known_A, known_B, multiple_intersection_point,
                    sign_at_init, start_box_step)
            except AcceptableFailure:
                # Go back to the top and try with a new start point
                print("\tOkay we need to try with a new start point")
                start_box_step = -10

            start_box_step -= 2

        if success is None:
            continue

        val = matmul(
            known_T.forward(multiple_intersection_point, with_relu=True),
            known_A, known_B)[new_change_axis]
        print("Value at multiple:", val)
        val = matmul(known_T.forward(success, with_relu=True), known_A,
                     known_B)[new_change_axis]
        print("Value at success:", val)

        if stepsize < 10:
            new_move_direction = success - multiple_intersection_point

            # We don't want to be right next to the multiple intersection point.
            # So let's binary search to find how far away we can go while remaining in this polytope.
            # Then we'll go half as far as we can maximally go.

            initial_signs = get_polytope_at(known_T, known_A, known_B, success)
            print("polytope at initial", sign_to_int(initial_signs))
            low = 0
            high = 1
            while high - low > 1e-2:
                mid = (high + low) / 2
                query_point = multiple_intersection_point + mid * new_move_direction
                next_signs = get_polytope_at(known_T, known_A, known_B,
                                             query_point)
                print(
                    "polytope at", mid, sign_to_int(next_signs), "%x" %
                    (sign_to_int(next_signs) ^ sign_to_int(initial_signs)))
                if initial_signs == next_signs:
                    low = mid
                else:
                    high = mid
            print("GO TO", mid)

            success = multiple_intersection_point + (mid / 2) * new_move_direction

            val = matmul(known_T.forward(success, with_relu=True), known_A,
                         known_B)[new_change_axis]
            print("Value at moved success:", val)

        print("Adding the points to the set of known good points")

        points_on_plane.append(start_point)

        if camefrom is not None:
            points_on_plane.append(camefrom)
        start_point = success
        start_box_step = max(stepsize - 1, 0)

    return points_on_plane, False
def improve_row_precision(args):
    """
    Improve the precision of an extracted row.

    We think we know where the row's hyperplane is, but let's actually
    figure it out for sure. To do this, start by sampling a bunch of points
    near where we expect the hyperplane to be. This gives us a picture like:

                    X         X
          X
     X        X
                     X                X

    where some points are correct and some are wrong. With some robust
    statistics, try to fit a hyperplane through most of the points
    (in high dimension!):

                    X     /   X
          X        /
                  /
     X        X  /
                /

    Concretely, the code below repeatedly solves ``hidden @ soln = 1`` by
    least squares on random subsets of the collected points and keeps the
    candidate with the smallest median residual.

    Args:
        args: tuple ``(LAYER, known_T, known_A, known_B, row, did_again)``;
            ``did_again`` is unpacked but not used here.

    Returns:
        ``(known_A[:, row], known_B[row])`` unchanged if the column's absolute
        sum is below 1e-8; otherwise ``(soln, -1)``, where the fitted equation
        is ``hidden @ soln - 1 ~= 0``.

    Note:
        Calls ``exit(1)`` if the chosen solution has any coefficient with
        magnitude below 1e-10.
    """
    (LAYER, known_T, known_A, known_B, row, did_again) = args
    logger.log("Improve the extracted neuron number", row, level=Logger.INFO)

    logger.log(np.sum(np.abs(known_A[:, row])), level=Logger.INFO)
    # Nothing to refine for an (effectively) all-zero column.
    if np.sum(np.abs(known_A[:, row])) < 1e-8:
        return known_A[:, row], known_B[row]

    def loss(x, r):
        # Sum of squares of row r's value on the hidden activations of x.
        hidden = known_T.forward(x, with_relu=True, np=jnp)
        dotted = matmul(hidden, jnp.array(known_A)[:, r], jnp.array(known_B)[r], np=jnp)

        return jnp.sum(jnp.square(dotted))

    # Gradient is taken from the un-jitted loss before the name is rebound.
    loss_grad = jax.jit(jax.grad(loss))
    loss = jax.jit(loss)

    extended_T = known_T.extend_by(known_A, known_B)

    def get_more_points(NUM):
        """
        Gather more points. This procedure is really kind of ugly and should
        probably be fixed. We want to find points that are near where we
        expect them to be.

        So begin by finding preimages to points that are on the line with
        gradient descent. This should be completely possible, because we
        have d_0 input dimensions but only want to control one inner layer.

        Loops until more than NUM points have been accepted and returns them.
        """
        logger.log("Gather some more actual critical points on the plane", level=Logger.INFO)
        stepsize = .1
        critical_points = []
        while len(critical_points) <= NUM:
            logger.log("On this iteration I have ", len(critical_points), "critical points on the plane", level=Logger.INFO)
            points = np.random.normal(0, 1e3, size=(100, DIM,))

            lr = 10
            for step in range(5000):
                # Use JaX's built in optimizer to do this.
                # We want to adjust the LR so that we get a better solution
                # as we optimize. Probably there is a better way to do this,
                # but this seems to work just fine.
                # No queries involved here.
                if step % 1000 == 0:
                    # Every 1000 steps: halve lr and rebuild Adam (its state is re-initialized).
                    lr *= .5
                    init, opt_update, get_params = jax.experimental.optimizers.adam(lr)

                    @jax.jit
                    def update(i, opt_state, batch):
                        # NOTE(review): params is unused.
                        params = get_params(opt_state)
                        return opt_update(i, loss_grad(batch, row), opt_state)
                    opt_state = init(points)

                if step % 100 == 0:
                    ell = loss(points, row)
                    if CHEATING:
                        # This isn't cheating, but makes things prettier
                        print(ell)
                    if ell < 1e-5:
                        break
                opt_state = update(step, opt_state, points)
                # NOTE(review): reads optimizer internals; get_params(opt_state) may be the public equivalent — confirm.
                points = opt_state.packed_state[0][0]

            for point in points:
                # For each point, try to see where it actually is.

                # First, if optimization failed, then abort.
                if loss(point, row) > 1e-5:
                    continue

                if LAYER > 0:
                    # If we're on a deeper layer, and if a prior layer is zero, then abort
                    if min(np.min(np.abs(x)) for x in known_T.get_hidden_layers(point)) < 1e-4:
                        logger.log("is on prior", level=Logger.INFO)
                        continue

                #print("Stepsize", stepsize)
                tmp = Tracker().query_count
                solution = do_better_sweep(offset=point,
                                           low=-stepsize,
                                           high=stepsize,
                                           known_T=known_T)
                #print("qs", query_count-tmp)
                # Adapt the sweep width: widen on no hit, shrink on multiple hits.
                if len(solution) == 0:
                    stepsize *= 1.1
                elif len(solution) > 1:
                    stepsize /= 2
                elif len(solution) == 1:
                    stepsize *= 0.98
                    potential_solution = solution[0]

                    # NOTE(review): hiddens is computed but unused. Iterating over
                    # this_hidden_vec yields its own elements, so the min() below
                    # equals this_hidden and the test reduces to this_hidden > 0.
                    # hiddens may have been the intended iterable — confirm.
                    hiddens = extended_T.get_hidden_layers(potential_solution)

                    this_hidden_vec = extended_T.forward(potential_solution)
                    this_hidden = np.min(np.abs(this_hidden_vec))
                    if min(np.min(np.abs(x)) for x in this_hidden_vec) > np.abs(this_hidden) * 0.9:
                        critical_points.append(potential_solution)
                    else:
                        logger.log("Reject it", level=Logger.INFO)
        logger.log("Finished with a total of", len(critical_points), "critical points", level=Logger.INFO)
        return critical_points

    critical_points_list = []
    for _ in range(1):
        NUM = sizes[LAYER] * 2
        critical_points_list.extend(get_more_points(NUM))

    critical_points = np.array(critical_points_list)

    hidden_layer = known_T.forward(np.array(critical_points), with_relu=True)

    if CHEATING:
        out = np.abs(matmul(hidden_layer, A[LAYER], B[LAYER]))
        which_neuron = int(np.median(which_is_zero(0, [out])))
        logger.log("NEURON NUM", which_neuron, level=Logger.INFO)

        crit_val_0 = out[:, which_neuron]

        logger.log(crit_val_0, level=Logger.INFO)

        # print(list(np.sort(np.abs(crit_val_0))))
        logger.log('probability ok', np.mean(np.abs(crit_val_0) < 1e-8), level=Logger.INFO)

    # Residual of the current (unrefined) row on the collected points.
    crit_val_1 = matmul(hidden_layer, known_A[:, row], known_B[row])

    best = (None, 1e6)
    upto = 100
    for iteration in range(upto):
        if iteration % 1000 == 0:
            logger.log("ITERATION", iteration, "OF", upto, level=Logger.INFO)
        # NOTE(review): `or True` makes the trim() branch below unreachable.
        if iteration % 2 == 0 or True:

            # Try 1000 times to make sure that we get at least one non-zero per axis
            for _ in range(1000):
                # NOTE(review): get_more_points only guarantees NUM+1 points, so
                # sampling NUM+2 without replacement can raise ValueError.
                randn = np.random.choice(len(hidden_layer), NUM + 2, replace=False)
                if np.all(np.any(hidden_layer[randn] != 0, axis=0)):
                    break

            hidden = hidden_layer[randn]
            soln, *rest = np.linalg.lstsq(hidden, np.ones(hidden.shape[0]))

        else:
            randn = np.random.choice(len(hidden_layer), min(len(hidden_layer), hidden_layer.shape[1] + 20), replace=False)
            soln, _ = trim(hidden_layer[randn], np.ones(hidden_layer.shape[0])[randn], hidden_layer.shape[1])

        crit_val_2 = matmul(hidden_layer, soln, None) - 1

        quality = np.median(np.abs(crit_val_2))

        if iteration % 100 == 0:
            logger.log('quality', quality, best[1], level=Logger.INFO)

        if quality < best[1]:
            best = (soln, quality)

        if quality < 1e-10: break
        if quality < 1e-10 and iteration > 1e4: break
        if quality < 1e-8 and iteration > 1e5: break

        soln, _ = best

        if CHEATING:
            logger.log("Compare", np.median(np.abs(crit_val_0)), level=Logger.INFO)
            logger.log("Compare", np.median(np.abs(crit_val_1)), best[1])

        # Stop once the best solution has no (near-)zero coefficient.
        if np.all(np.abs(soln) > 1e-10):
            break

    logger.log('soln', soln, level=Logger.INFO)

    if np.any(np.abs(soln) < 1e-10):
        logger.log("THIS IS BAD. FIX ME NOW.", level=Logger.ERROR)
        exit(1)

    # NOTE(review): after the exit(1) check above no entry is < 1e-10, so this
    # replacement appears to be a no-op.
    rescale = np.median(soln / known_A[:, row])
    soln[np.abs(soln) < 1e-10] = known_A[:, row][np.abs(soln) < 1e-10] * rescale

    if CHEATING:
        other = A[LAYER][:, which_neuron]
        logger.log("real / mine / diff", level=Logger.INFO)
        logger.log(other / other[0], level=Logger.INFO)
        logger.log(soln / soln[0], level=Logger.INFO)
        logger.log(known_A[:, row] / known_A[:, row][0], level=Logger.INFO)
        logger.log(other / other[0] - soln / soln[0], level=Logger.INFO)

    # NOTE(review): `or True` means the refined row is always returned.
    if best[1] < np.mean(np.abs(crit_val_1)) or True:
        return soln, -1
    else:
        logger.log("FAILED TO IMPROVE ACCURACY OF ROW", row, level=Logger.INFO)
        logger.log(np.mean(np.abs(crit_val_2)), 'vs', np.mean(np.abs(crit_val_1)), level=Logger.INFO)
        return known_A[:, row], known_B[row]
def compute_layer_values(critical_points, known_T, LAYER):
    """
    Extract one normal vector and bias per neuron (neuron_count[LAYER+1] of them).

    Repeatedly gathers ``(ratios, critical_point)`` pairs via gather_ratios,
    drops near-duplicate points, and calls graph_solve until it succeeds:

    - On GatherMoreData, more points are obtained from
      sign_recovery.solve_layer_sign (fed from ``e.data``) and added.
    - On AcceptableFailure, the search is retried. If the exception carries a
      non-empty ``partial_solution``, check_fn then rejects points where any
      partial neuron's value is within 1e-4 of zero.

    Args:
        critical_points: witnesses iterated by reuse_critical_points on every
            pass of the retry loop (a one-shot iterator would be exhausted
            after the first pass — TODO confirm callers pass a sequence).
        known_T: transformation for the layers extracted so far.
        LAYER: index of the layer being solved.

    Returns:
        ``(extracted_normals, extracted_bias, mask)``. extracted_normals is the
        transpose of graph_solve's normals (one column per neuron). The bias is
        ``-(hidden . w)`` at the first critical point of each group, zeroed for
        all-zero columns. mask is ``[1] * n`` unless CHEATING, where it holds
        ``A[LAYER][0, j]`` for the best-matching true column j.
    """
    if LAYER == 0:
        COUNT = neuron_count[LAYER+1] * 3
    else:
        # NOTE(review): this makes COUNT a float.
        COUNT = neuron_count[LAYER+1] * np.log(sizes[LAYER+1]) * 3

    # type: [(ratios, critical_point)]
    this_layer_critical_points = []

    # Filled in from AcceptableFailure.partial_solution below; read by check_fn
    # through the closure at call time.
    partial_weights = None
    partial_biases = None

    def check_fn(point):
        # Accept every point until a partial solution is known.
        if partial_weights is None:
            return True
        hidden = matmul(known_T.forward(point, with_relu=True), partial_weights.T, partial_biases)
        if np.any(np.abs(hidden) < 1e-4):
            return False

        return True

    print()
    print("Start running critical point search to find neurons on layer", LAYER)
    while True:
        print("At this iteration I have", len(this_layer_critical_points), "critical points")

        def reuse_critical_points():
            for witness in critical_points:
                yield witness

        this_layer_critical_points.extend(gather_ratios(reuse_critical_points(), known_T, check_fn,
                                                        LAYER, COUNT))

        print("Query count after that search:", query_count)
        print("And now up to ", len(this_layer_critical_points), "critical points")

        ## filter out duplicates
        filtered_points = []

        # Let's not add points that are identical to ones we've already done.
        # A point is kept only if no LATER point lies within 1e-10 of it
        # (for/else), so the last copy of each duplicate survives. O(n^2).
        for i, (ratio1, point1) in enumerate(this_layer_critical_points):
            for ratio2, point2 in this_layer_critical_points[i+1:]:
                if np.sum((point1 - point2)**2)**.5 < 1e-10:
                    break
            else:
                filtered_points.append((ratio1, point1))

        this_layer_critical_points = filtered_points

        print("After filtering duplicates we're down to ", len(this_layer_critical_points), "critical points")

        print("Start trying to do the graph solving")
        try:
            critical_groups, extracted_normals = graph_solve([x[0] for x in this_layer_critical_points],
                                                             [x[1] for x in this_layer_critical_points],
                                                             neuron_count[LAYER+1],
                                                             LAYER=LAYER,
                                                             debug=True)
            break
        except GatherMoreData as e:
            print("Graph solving failed because we didn't explore all sides of at least one neuron")
            print("Fall back to the hyperplane following algorithm in order to get more data")

            # Yield the points in e.data one at a time, logging as it goes.
            def mine(r):
                while len(r) > 0:
                    print("Yielding a point")
                    yield r[0]
                    r = r[1:]
                print("No more to give!")

            # All but the last known layer; the last layer's A/B go to solve_layer_sign.
            prev_T = KnownT(known_T.A[:-1], known_T.B[:-1])

            _, more_critical_points = sign_recovery.solve_layer_sign(prev_T, known_T.A[-1], known_T.B[-1], mine(e.data),
                                                                     LAYER-1, already_checked_critical_points=True,
                                                                     only_need_positive=True)

            print("Add more", len(more_critical_points))
            this_layer_critical_points.extend(gather_ratios(more_critical_points, known_T, check_fn,
                                                            LAYER, 1e6))
            print("Done adding")

            COUNT = neuron_count[LAYER+1]
        except AcceptableFailure as e:
            print("Graph solving failed; get more points")
            COUNT = neuron_count[LAYER+1]
            if 'partial_solution' in dir(e):

                if len(e.partial_solution[0]) > 0:
                    partial_weights, corresponding_examples = e.partial_solution
                    print("Got partial solution with shape", partial_weights.shape)
                    if CHEATING:
                        print("Corresponding to", np.argmin(np.abs(cheat_get_inner_layers([x[0] for x in corresponding_examples])[LAYER]), axis=1))

                    # Bias for each partial neuron: minus the median of hidden . weight
                    # over its example points.
                    partial_biases = []
                    for weight, examples in zip(partial_weights, corresponding_examples):

                        hidden = known_T.forward(examples, with_relu=True)
                        print("hidden", np.array(hidden).shape)
                        bias = -np.median(np.dot(hidden, weight))
                        partial_biases.append(bias)
                    partial_biases = np.array(partial_biases)

    print("Number of critical points per cluster", [len(x) for x in critical_groups])

    point_per_class = [x[0] for x in critical_groups]

    extracted_normals = np.array(extracted_normals).T

    # Compute the bias because we know wx+b=0
    extracted_bias = [matmul(known_T.forward(point_per_class[i], with_relu=True), extracted_normals[:, i], c=None) for i in range(neuron_count[LAYER+1])]

    # Don't forget to negate it.
    # That's important.
    # No, I definitely didn't forget this line the first time around.
    extracted_bias = -np.array(extracted_bias)

    # For the failed-to-identify neurons, set the bias to zero
    # NOTE(review): if each matmul above returns a scalar, extracted_bias is
    # 1-D (n,) and multiplying by an (n, 1) mask broadcasts it to (n, n).
    # Confirm the shape known_T.forward returns for a single point.
    extracted_bias *= np.any(extracted_normals != 0, axis=0)[:, np.newaxis]

    if CHEATING:
        # Compute how far we off from the true matrix
        real_scaled = A[LAYER]/A[LAYER][0]
        extracted_scaled = extracted_normals/extracted_normals[0]

        mask = []
        reorder_rows = []
        for i in range(len(extracted_bias)):
            which_idx = np.argmin(np.sum(np.abs(real_scaled - extracted_scaled[:, [i]]), axis=0))
            reorder_rows.append(which_idx)
            mask.append((A[LAYER][0, which_idx]))

        print('matrix norm difference', np.sum(np.abs(extracted_normals*mask - A[LAYER][:, reorder_rows])))
    else:
        mask = [1]*len(extracted_bias)

    return extracted_normals, extracted_bias, mask