# Build one encoder row per neuron for four populations: place cells, grid
# cells, band cells, and a "mixed" population copied from one of the others.
# NOTE(review): depends on names defined outside this snippet (args, rng,
# limit, phis, angles, hilbert_2d, encode_func, grid_cell_encoder,
# band_cell_encoder); their exact semantics are not visible here.
n_neurons = args.dim * args.neurons_per_dim
# One preferred location per neuron (row n is read below). The -limit/limit
# arguments presumably bound the sampled region, and N=2 suggests 2-D points.
# Assumes shape (n_neurons, 2) — TODO confirm against hilbert_2d.
preferred_locations = hilbert_2d(-limit, limit, n_neurons, rng, p=8, N=2, normal_std=3)

# Each population gets an (n_neurons, args.dim) matrix, filled row by row.
encoders_place_cell = np.zeros((n_neurons, args.dim))
encoders_band_cell = np.zeros((n_neurons, args.dim))
encoders_grid_cell = np.zeros((n_neurons, args.dim))
encoders_mixed = np.zeros((n_neurons, args.dim))
# Intercepts for the mixed population, appended in the same loop as its rows.
mixed_intercepts = []
# NOTE: the rng draws below happen in a fixed order each iteration. Reordering
# them changes which random values each neuron receives.
for n in range(n_neurons):
    # Shared index into phis/angles: the grid and band encoders for this
    # neuron use the same phi, angle and toroid_index.
    ind = rng.randint(0, len(phis))
    encoders_place_cell[n, :] = encode_func(preferred_locations[n, :])

    encoders_grid_cell[n, :] = grid_cell_encoder(
        location=preferred_locations[n, :],
        dim=args.dim, phi=phis[ind], angle=angles[ind],
        toroid_index=ind
    )

    # Band orientation choice. If rng is a numpy RandomState, the upper bound
    # is exclusive, giving 0, 1 or 2 — confirm rng's type.
    band_ind = rng.randint(0, 3)
    encoders_band_cell[n, :] = band_cell_encoder(
        location=preferred_locations[n, :],
        dim=args.dim, phi=phis[ind], angle=angles[ind],
        toroid_index=ind,
        band_index=band_ind
    )

    # Select which population this mixed-population neuron copies.
    mix_ind = rng.randint(0, 3)
    # NOTE(review): only the place-cell case (0) is present in this snippet.
    # Other branches look cut off by the scrape. As written, rows with
    # mix_ind 1/2 stay zero and no intercept is appended for them.
    if mix_ind == 0:
        encoders_mixed[n, :] = encoders_place_cell[n, :]
        mixed_intercepts.append(.3)
# Example #2
# Plot how strongly one grid-cell encoder responds across a square grid of
# sample locations.
# NOTE(review): relies on names defined elsewhere in the script: dim, X, Y,
# get_heatmap_vectors, grid_cell_encoder and plt.

limit = 10  # half-width of the sampled square (previously 5)
res = 128  # number of samples along each axis

# Evenly spaced sample coordinates from -limit to +limit, inclusive.
xs = np.linspace(start=-limit, stop=limit, num=res)
ys = np.linspace(start=-limit, stop=limit, num=res)

# Encoded vectors for the sample grid. Axis 2 is contracted against the
# encoder below, so it presumably has length dim — confirm with helper.
hmv = get_heatmap_vectors(xs, ys, X, Y)

# The encoder actually plotted.
encoder = grid_cell_encoder(
    location=(0, 0),
    dim=dim,
    phi=np.pi / 2.0,
    angle=np.pi / 3.0,
    toroid_index=0,
)

# Dot product of the encoder with every heatmap vector (encoder axis 0
# against hmv axis 2), rendered as an image in a new figure.
plt.figure()
sim = np.tensordot(encoder, hmv, axes=((0,), (2,)))
plt.imshow(sim)

# Alternative encoders tried during exploration, kept for reference.
# NOTE(review): these pass `phases=` / `phase=`. The calls in use pass
# phi/angle/location instead, so update them before uncommenting.
# encoder = grid_cell_encoder(dim=dim, phases=(np.pi/1., np.pi/1., np.pi/1.), toroid_index=1)
# encoder = grid_cell_encoder(dim=dim, phases=(0, 0, 0), toroid_index=0)
# encoder = band_cell_encoder(dim=dim, phase=0, toroid_index=0, band_index=0)
# encoder = grid_cell_encoder(dim=dim, phases=(np.pi/2., np.pi/2., np.pi/2.), toroid_index=1)
# encoder = grid_cell_encoder(dim=dim, phases=(np.pi/1., np.pi/1., np.pi/1.), toroid_index=1)
# encoder = grid_cell_encoder(dim=dim, phases=(np.pi/1., 0, 0), toroid_index=0)
# encoder = grid_cell_encoder(dim=dim, phases=(np.pi/1., np.pi/1., np.pi/1.), toroid_index=0)
# encoder = grid_cell_encoder(dim=dim, phases=(0, np.pi/1., np.pi/1.), toroid_index=0)
# encoder = grid_cell_encoder(dim=dim, phases=(0, np.pi/1., 0), toroid_index=0)
# encoder = band_cell_encoder(dim=dim, phase=1*np.pi/1., toroid_index=0, band_index=0)