def single_neuron_time(time, dt, scale, step_size, I_function, shape, spike_trace, additive_spike_trace, tau_s, trace_scale, is_inhibitory, learning, R, C, delta_t, threshold_rh, threshold_r=-55):
    """Drive one ELIFPopulation with a generated input and record its state.

    The input sequence is produced by ``I_function(time, step_size, scale)``.
    The population arguments (``shape`` through ``threshold_r``) are passed
    positionally to ``ELIFPopulation``. ``dt`` is assigned to the population
    after construction and is also used to size the monitor.

    Returns:
        ``(neuron, I, spikes, potentials)`` where ``spikes`` is the monitor's
        ``"s"`` record multiplied by 1 (presumably turning a boolean spike
        record into integers) with its first two axes swapped, and
        ``potentials`` is the monitor's ``"u"`` record as returned.
    """
    current = I_function(time, step_size, scale)
    population = ELIFPopulation(
        shape, spike_trace, additive_spike_trace, tau_s, trace_scale,
        is_inhibitory, learning, R, C, delta_t, threshold_rh, threshold_r,
    )
    population.dt = dt

    recorder = Monitor(population, state_variables=["s", "u"])
    recorder.set_time_steps(time, dt)
    recorder.reset_state_variables()

    # One simulation step per entry of the input. Only element [0] of each
    # entry is fed in (NOTE(review): presumably each entry of I carries an
    # extra dimension — confirm against I_function).
    for step_input in current:
        population.forward(step_input[0])
        recorder.record()

    spikes = torch.transpose(recorder.get("s") * 1, 0, 1)
    potentials = recorder.get("u")
    return population, current, spikes, potentials
# NOTE(review): this statement run uses n1, n2, con1, encoder, shape1, shape2,
# time, dt and I without defining them; they must come from a scope that is
# not visible here — TODO confirm.
# Simulate two populations: n1 is driven by I, and n2 is driven by the output
# of con1. Then decode both spike records and plot them.
monitor1 = Monitor(n1, state_variables=["s"])
monitor1.set_time_steps(time, dt)
monitor1.reset_state_variables()
monitor2 = Monitor(n2, state_variables=["s"])
monitor2.set_time_steps(time, dt)
monitor2.reset_state_variables()
# No con1 output exists before the first step, so n2 starts with zero input.
I_ex = 0
for i in range(len(I)):
    n1.forward(I[i])
    n2.forward(I_ex)
    # Computed after both forwards, so on each step n2 receives whatever
    # con1.compute() returned on the previous step.
    I_ex = con1.compute()
    monitor1.record()
    monitor2.record()
# Collapse every axis after the first into one (presumably giving
# (time, neurons) — TODO confirm the monitor's layout).
s1 = monitor1.get("s").flatten(start_dim=1)
s2 = monitor2.get("s").flatten(start_dim=1)
# NOTE(review): the flattened I is not used again in the visible code.
I = I.flatten(start_dim=1)
# encoder.decode receives each spike array together with shape1 / shape2.
# What it reconstructs depends on the encoder, which is not visible here.
d1 = encoder.decode(s1.numpy(), shape1)
d2 = encoder.decode(s2.numpy(), shape2)
plot = plotting()
plot.plot_visual_activity(s2.T, d1, d2)
plot.show()
# NOTE(review): this continues code outside the visible range. monitor1
# (presumably a Monitor on pn1), pn1, pn2, con1-con3, I1, I2, time, dt and
# scale are defined earlier — TODO confirm.
monitor1.reset_state_variables()
monitor2 = Monitor(pn2, state_variables=["s", "u"])
monitor2.set_time_steps(time, dt)
monitor2.reset_state_variables()
# Connection outputs from the previous step; all zero before the first step.
I_in = 0
I_self = 0
I_ex = 0
for i in range(len(I1)):
    # pn1 input: external current + con1 output - con2 output.
    pn1.forward(I1[i] + I_self - I_in)
    # pn2 input: external current + con3 output.
    pn2.forward(I2[i] + I_ex)
    # The connection outputs are computed after both populations step, so
    # they are applied on the next iteration.
    I_self = con1.compute()
    I_in = con2.compute()
    I_ex = con3.compute()
    monitor1.record()
    monitor2.record()
# Multiply by 1 (presumably bool -> int) and swap the first two axes.
s1 = torch.transpose(monitor1.get("s") * 1, 0, 1)
s2 = torch.transpose(monitor2.get("s") * 1, 0, 1)
plot = plotting()
plot.plot_population_activity_init(time / scale)
plot.plot_population_activity_update(s1, I1, mode="p1")
# start_idx is the product of every s1 axis except the last. If s1 is
# (neurons, time), this equals pn1's neuron count, so pn2 is drawn after
# pn1 — TODO confirm the layout.
plot.plot_population_activity_update(s2, I2, start_idx=int(s1.numel() / s1.shape[-1]), mode="p2")
plot.show()
# NOTE(review): this continues code outside the visible range. The other out_*
# accumulators, the populations ep1/ep2/ip1, the con_* connections, the
# inputs I_ep1/I_ep2/I_ip1, the monitors, time and scale are defined
# earlier — TODO confirm.
out_ep2_ep2 = 0
# NOTE(review): the loop runs `time` steps (so time must be an int here),
# whereas the plot init below receives time / scale — confirm the intended
# units.
for i in range(time):
    # The output of con_ip1_ep1 / con_ip1_ep2 is subtracted from ep1 / ep2,
    # and the output of con_ep1_ep1 / con_ep2_ep2 is added to it.
    ep1.forward(I_ep1[i] - out_ip1_ep1 + out_ep1_ep1)
    ep2.forward(I_ep2[i] - out_ip1_ep2 + out_ep2_ep2)
    # ip1 input: external current + outputs of con_ep1_ip1 and con_ep2_ip1.
    ip1.forward(I_ip1[i] + out_ep1_ip1 + out_ep2_ip1)
    # Every connection output is refreshed after all populations have
    # stepped, so the new values take effect on the next iteration.
    out_ep1_ip1 = con_ep1_ip1.compute()
    out_ep2_ip1 = con_ep2_ip1.compute()
    out_ip1_ep1 = con_ip1_ep1.compute()
    out_ip1_ep2 = con_ip1_ep2.compute()
    out_ep1_ep1 = con_ep1_ep1.compute()
    out_ep2_ep2 = con_ep2_ep2.compute()
    monitor_ep1.record()
    monitor_ep2.record()
    monitor_ip1.record()
# Multiply by 1 (presumably bool -> int) and swap the first two axes.
s_ep1 = torch.transpose(monitor_ep1.get("s") * 1, 0, 1)
s_ep2 = torch.transpose(monitor_ep2.get("s") * 1, 0, 1)
s_ip1 = torch.transpose(monitor_ip1.get("s") * 1, 0, 1)
plot = plotting()
plot.plot_three_population_activity_init(time / scale)
plot.plot_three_population_activity_update(s_ep1, I_ep1, s_ep2, I_ep2, s_ip1, I_ip1, n1="ep1", n2="ep2", n3="ip1")