def compute_sparse_weights(origin, post, transform, fan_in, noise=0.1, num_samples=100):
    """Compute a sparse presynaptic -> postsynaptic weight matrix.

    ``num_samples`` random subsets of ``fan_in`` presynaptic neurons are drawn.
    For each subset a decoder is solved from the noise-perturbed activities of
    only those neurons. Each postsynaptic neuron then picks one
    (subset, decoder) pair at random. Its weights onto the neurons of that
    subset are ``decoder . (transform^T . encoder) / radius``; every other
    weight in its row stays zero.

    Args:
        origin: presynaptic origin. ``origin.node.getDecodingApproximator('AXON')``
            supplies ``evalPoints`` and ``values``. Each entry of
            ``origin.functions`` must provide ``multiMap``.
        post: postsynaptic population; ``post.encoders`` and ``post.radii[0]``
            are read.
        transform: transform matrix as nested lists, or any object with a
            ``tolist()`` method (converted to nested lists first).
        fan_in: number of presynaptic neurons each postsynaptic neuron
            receives weights from.
        noise: noise standard deviation, as a fraction of the maximum value in
            the activity matrix.
        num_samples: number of random presynaptic subsets (and decoders) to
            draw.

    Returns:
        A list with one row per postsynaptic encoder. Each row is a float32
        numpy array with one entry per presynaptic neuron (row of the
        activity matrix).

    Raises:
        ValueError: if ``num_samples`` is less than 1. ``random.sample`` also
            raises ``ValueError`` if ``fan_in`` is negative or larger than the
            number of presynaptic neurons.
    """
    if num_samples < 1:
        raise ValueError("num_samples must be at least 1")

    encoder = post.encoders
    # NOTE(review): only the first radius is used, which presumably assumes
    # every postsynaptic dimension shares one radius -- confirm with callers.
    radius = post.radii[0]

    if hasattr(transform, 'tolist'):
        transform = transform.tolist()

    approx = origin.node.getDecodingApproximator('AXON')

    # Target values: evaluate every origin function at the evaluation points,
    # then transpose so each row is one evaluation point and each column one
    # function output.
    X = approx.evalPoints
    X = MU.transpose([f.multiMap(X) for f in origin.functions])

    # Activity matrix: one row per presynaptic neuron, one column per
    # evaluation point (inferred from len(A) / len(A[0]) usage below).
    A = approx.values

    n_pre = len(A)
    n_points = len(A[0])
    n_post = len(encoder)
    noise_sd = MU.max(A) * noise

    # Each draw keeps the sampled presynaptic indices together with the
    # decoder solved from exactly those neurons: decoder row j belongs to
    # presynaptic neuron indices[j]. Keeping them paired is essential --
    # applying a decoder to a different subset's indices scrambles weights.
    draws = []
    for _ in range(num_samples):
        indices = random.sample(range(n_pre), fan_in)
        activity = [A[j] for j in indices]
        jitter = [[random.gauss(0, noise_sd) for _ in range(n_points)]
                  for _ in range(fan_in)]
        activity = MU.sum(activity, jitter)

        gamma = MU.prod(activity, MU.transpose(activity))
        upsilon = MU.prod(activity, X)

        # NOTE(review): second argument is presumably a singular-value
        # cutoff / regularisation for pinv -- confirm against its definition.
        gamma_inv = pinv(gamma, noise_sd * noise_sd)

        # Copy pinv's result row by row into plain lists before handing it to
        # MU.prod (mirrors the original conversion).
        decoder = MU.prod([list(row) for row in gamma_inv], upsilon)
        draws.append((indices, decoder))

    w_sparse = np.zeros((n_post, n_pre), 'f')
    transform_T = MU.transpose(transform)  # loop-invariant; compute once
    for i in range(n_post):
        indices, decoder = random.choice(draws)
        ww = MU.prod(decoder, MU.prod(transform_T, encoder[i]))
        for j, k in enumerate(indices):
            w_sparse[i, k] = float(ww[j]) / radius

    return list(w_sparse)