def get_error_history(dataset, add_jitter, encoder_dims, max_encoder_rate,
                      max_decoder_rate, n_iters=10000, lr=0.01, log_every=100):
    """Train a fresh ``RIPCoder`` on ``dataset`` and record its error per step.

    Every iteration runs the coder over all samples, sums the MSE between
    component ``[0]`` of the polar form of the reconstruction and of the
    input, and takes one SGD step on that summed error.

    Args:
        dataset: Indexable collection supporting ``len()``; each item is fed
            to ``ripcoder.forward`` and ``py_STDP.cart_to_polar``.
        add_jitter: If truthy, each sample is passed through ``jitter``
            before being encoded (fresh jitter on every visit).
        encoder_dims: Forwarded to ``RIPCoder``.
        max_encoder_rate: Forwarded to ``RIPCoder``.
        max_decoder_rate: Forwarded to ``RIPCoder``.
        n_iters: Number of optimisation steps (default 10000).
        lr: SGD learning rate (default 0.01).
        log_every: Print the error every this many steps; a falsy value
            disables printing (default 100).

    Returns:
        Tuple ``(error_hist, ripcoder)``: the list of per-iteration summed
        errors as Python floats, and the trained coder.
    """
    ripcoder = RIPCoder(encoder_dims, max_encoder_rate, max_decoder_rate)
    optimizer = torch.optim.SGD(ripcoder.parameters(), lr=lr)
    mse = nn.MSELoss()
    error_hist = []

    for i in range(n_iters):
        # Leaf accumulator that requires grad, so backward() stays valid even
        # when the dataset is empty and nothing is added to it.
        error = torch.zeros(1, requires_grad=True)
        # Index-based access is kept on purpose: only __len__/__getitem__
        # are required of `dataset`.
        for j in range(len(dataset)):
            x = jitter(dataset[j]) if add_jitter else dataset[j]
            x_hat, _ = ripcoder.forward(x)  # second output (the code) is unused
            x_hat_polar = py_STDP.cart_to_polar(x_hat)
            x_polar = py_STDP.cart_to_polar(x)
            # Only component [0] of the polar form is compared.
            error = error + mse(x_hat_polar[0], x_polar[0])

        error_value = error.item()
        if log_every and i % log_every == 0:
            print(' {} - ERROR:{:.5f}'.format(i, error_value))
        error_hist.append(error_value)

        optimizer.zero_grad()
        # NOTE(review): the graph is rebuilt every iteration, so retain_graph
        # looks unnecessary unless RIPCoder keeps graph-attached state between
        # calls; kept as-is until that is confirmed.
        error.backward(retain_graph=True)
        optimizer.step()

    return error_hist, ripcoder
def jitter(complex_tensor):
    """Return ``complex_tensor`` with random noise added to its phase.

    The input is converted with ``py_STDP.cart_to_polar``; component ``[0]``
    (rate) is left untouched and component ``[1]`` (phase) is shifted by a
    value drawn uniformly from ``[0, 1/20)`` per element. The result is
    converted back with ``py_STDP.polar_to_cart``.
    """
    polar = py_STDP.cart_to_polar(complex_tensor)
    # torch.rand_like samples uniformly from [0, 1); scale it down to 1/20.
    phase_noise = torch.rand_like(polar[1]) / 20
    return py_STDP.polar_to_cart((polar[0], polar[1] + phase_noise))
def run_sim(index, arg_dict):
    """Run one RIP-layer simulation configured by ``arg_dict``.

    A square ``py_STDP.RIPLayer`` (N inputs, N outputs) is driven by its own
    output for ``T + 1`` steps, starting from rate ``f_max`` and uniformly
    random phases. At every step the phase histogram and its entropy are
    recorded, then the layer is updated with ``rip_learn``.

    NOTE(review): a second ``run_sim`` with a different signature is defined
    later in this file and shadows this one; confirm which is intended.

    Args:
        index: Identifier of this run; only used in the description string.
        arg_dict: Mapping with keys ``'N'``, ``'f_max'``, ``'T'``,
            ``'nbins'``, ``'homeostasis'``, ``'sparsity'``, ``'a_param'``,
            ``'stdp_weight'`` and ``'hebb_weight'``.

    Returns:
        Tuple ``(dist_hist, entropy_traj, desc_string)``:
        ``dist_hist`` is a tensor of shape ``(nbins, T + 1)`` holding the
        phase density per step, ``entropy_traj`` is the list of per-step
        entropies, and ``desc_string`` is ``'index=<i>,N=<..>,f_max=<..>,...'``.

    Raises:
        KeyError: If ``arg_dict`` lacks one of the required keys.
    """
    param_names = ('N', 'f_max', 'T', 'nbins', 'homeostasis', 'sparsity',
                   'a_param', 'stdp_weight', 'hebb_weight')
    params = {name: arg_dict[name] for name in param_names}

    N = params['N']
    f_max = params['f_max']
    T = params['T']
    nbins = params['nbins']
    a_param = params['a_param']
    stdp_weight = params['stdp_weight']
    hebb_weight = params['hebb_weight']

    # The previous version iterated inspect.getfullargspec(run_sim) and called
    # eval() on each field; the fields are lists/None/dicts, not names, so it
    # raised TypeError. The description is built from the values directly.
    desc_string = 'index={},'.format(index) + ','.join(
        '{}={}'.format(name, value) for name, value in params.items())

    network = py_STDP.RIPLayer(N, N, f_max, params['homeostasis'],
                               params['sparsity'])
    x_polar = (f_max * torch.ones(1, N), torch.rand(1, N))

    hist_vects = []
    entropy_traj = []
    bin_edges = np.linspace(0, 1, nbins + 1)  # loop-invariant, computed once
    x = py_STDP.polar_to_cart(x_polar)
    for i in range(T + 1):
        if i % 100 == 0:
            print(' {}'.format(i))
        xp = py_STDP.cart_to_polar(x)
        # Histogram of component [1] (phase) over the fixed [0, 1] bins.
        hist_vect = np.histogram(xp[1].detach(), bin_edges, density=True)[0]
        # Normalise to a probability vector before taking the entropy.
        p = hist_vect / hist_vect.sum()
        entropy_traj.append(entropy(p))
        hist_vects.append(torch.tensor(hist_vect).unsqueeze(0))
        y = network.forward(x)
        network.rip_learn(x, y, a_param, stdp_weight=stdp_weight,
                          hebb_weight=hebb_weight)
        x = y  # feed the output back in as the next input

    # Stack (T + 1, nbins) and flip to (nbins, T + 1).
    dist_hist = torch.cat(hist_vects, dim=0).transpose(0, 1)
    return dist_hist, entropy_traj, desc_string
def run_sim(N, f_max, T, nbins):
    """Simulate a self-driven RIP layer and track its phase distribution.

    Builds ``py_STDP.RIPLayer(N, N, f_max, True)``, starts from rate
    ``f_max`` with uniformly random phases, and for ``T + 1`` steps feeds the
    layer's output back as its input. Each step records the phase density
    over ``nbins`` equal bins on ``[0, 1]`` and its entropy, then applies
    ``rip_learn`` with fixed settings (4, ``stdp_weight=100``,
    ``hebb_weight=10``).

    NOTE(review): an earlier ``run_sim(index, arg_dict)`` in this file has
    the same name; this definition replaces it.

    Returns:
        Tuple ``(dist_hist, entropy_traj)``: a ``(nbins, T + 1)`` tensor of
        per-step phase densities and the list of per-step entropies.
    """
    layer = py_STDP.RIPLayer(N, N, f_max, True)
    initial_polar = (f_max * torch.ones(1, N), torch.rand(1, N))

    histograms = []
    entropies = []
    bin_edges = np.linspace(0, 1, nbins + 1)
    state = py_STDP.polar_to_cart(initial_polar)

    for step in range(T + 1):
        if step % 100 == 0:
            print(' {}'.format(step))

        phase = py_STDP.cart_to_polar(state)[1]
        density = np.histogram(phase.detach(), bin_edges, density=True)[0]
        # Normalise the density to a probability vector for the entropy.
        entropies.append(entropy(density / sum(density)))
        histograms.append(torch.tensor(density).unsqueeze(0))

        output = layer.forward(state)
        layer.rip_learn(state, output, 4, stdp_weight=100, hebb_weight=10)
        state = output  # the output becomes the next input

    stacked = torch.cat(histograms, dim=0)  # (T + 1, nbins)
    return stacked.transpose(0, 1), entropies
# Experiment 1: iterate a RIP layer on its own output (no learning step)
# and record the phase histogram and its entropy at every step.
N, T, nbins, f_max = 100, 10000, 25, 1

network = py_STDP.RIPLayer(N, N, f_max, True)
# Initial state in polar form: rate f_max everywhere, uniformly random phase.
x_polar = (f_max * torch.ones(1, N), torch.rand(1, N))

hist_vects_1 = []
entropy_1 = []
x = py_STDP.polar_to_cart(x_polar)
for i in range(T + 1):
    if i % 100 == 0:
        print(i)
    xp = py_STDP.cart_to_polar(x)
    hist_vect = np.histogram(
        xp[1].detach(), np.linspace(0, 1, nbins + 1), density=True
    )[0]
    # Normalise the density to a probability vector for the entropy.
    p = hist_vect / sum(hist_vect)
    entropy_1.append(entropy(p))
    hist_vects_1.append(torch.tensor(hist_vect).unsqueeze(0))
    x = network.forward(x)

# Stack to (T + 1, nbins), then flip to (nbins, T + 1).
hist_hist_1 = torch.cat(hist_vects_1, dim=0).transpose(0, 1)

# Experiment 2 starts from the same initial state as experiment 1.
hist_vects_2 = []
entropy_2 = []
x = py_STDP.polar_to_cart(x_polar)