def test_homo_values(self):
    """Check ``event_sum`` with a scalar (homogeneous) synaptic value.

    A random fixed-probability connection in CSR form (``post_ids``,
    ``indptr``) is driven by a random boolean spike vector. The result is
    compared against a NumPy reference. That reference adds ``value`` to
    ``post_ids[j]`` for every synapse ``j`` of every presynaptic neuron
    whose event is ``True``.
    """
    import numpy as np

    bp.math.random.seed(1345)
    size = 200
    conn = bp.conn.FixedProb(prob=0.5, seed=123)
    conn(pre_size=size, post_size=size)
    post_ids, indptr = conn.require('pre2post')
    sps = bm.random.random(size).value < 0.5
    value = 3.0233
    a = event_sum(sps, (post_ids.value, indptr.value), size, value)

    # NumPy reference. Row i of the CSR structure owns synapses
    # indptr[i]:indptr[i+1], so repeating each presynaptic event flag
    # np.diff(indptr) times gives one "active" flag per synapse.
    ids = np.asarray(post_ids.value)
    ptr = np.asarray(indptr.value)
    lo, hi = int(ptr[0]), int(ptr[-1])
    syn_active = np.repeat(np.asarray(sps), np.diff(ptr))
    expected = np.zeros(size)
    # np.add.at accumulates correctly when a post index appears repeatedly.
    np.add.at(expected, ids[lo:hi][syn_active], value)
    # The kernel accumulates in lower precision and in a different order,
    # so compare with a tolerance rather than exactly.
    np.testing.assert_allclose(np.asarray(a), expected, rtol=1e-4, atol=1e-4)
def pre2post_event_sum(events, pre2post, post_num, values=1.):
    """The pre-to-post synaptic computation with event-driven summation.

    ``pre2post`` is a CSR-style pair ``(post_ids, indptr)``. The synapses
    of presynaptic neuron ``i`` are ``post_ids[indptr[i]:indptr[i + 1]]``.

    When ``values`` is a scalar, this function is equivalent to

    .. highlight:: python
    .. code-block:: python

       post_val = np.zeros(post_num)
       post_ids, indptr = pre2post
       pre_num = len(indptr) - 1
       for i in range(pre_num):
         if events[i]:
           for j in range(indptr[i], indptr[i + 1]):
             post_val[post_ids[j]] += values

    When ``values`` is a vector (with the length of ``len(post_ids)``),
    this function is equivalent to

    .. highlight:: python
    .. code-block:: python

       post_val = np.zeros(post_num)
       post_ids, indptr = pre2post
       pre_num = len(indptr) - 1
       for i in range(pre_num):
         if events[i]:
           for j in range(indptr[i], indptr[i + 1]):
             post_val[post_ids[j]] += values[j]

    Parameters
    ----------
    events: JaxArray, jax.numpy.ndarray, Variable
      The events, must be bool.
    pre2post: tuple of JaxArray, tuple of jax.numpy.ndarray
      A tuple ``(post_ids, indptr)`` containing the pre-to-post
      connection information.
    post_num: int
      The number of neurons in the post-synaptic group.
    values: float, JaxArray, jax.numpy.ndarray
      The value(s) to sum: a scalar shared by all synapses, or one
      value per synapse.

    Returns
    -------
    out: JaxArray, jax.numpy.ndarray
      A tensor with the shape of ``post_num``.
    """
    _check_brainpylib(pre2post_event_sum.__name__)
    indices, indptr = pre2post
    # Unwrap any BrainPy containers into raw device arrays for the kernel.
    events = as_device_array(events)
    indices = as_device_array(indices)
    indptr = as_device_array(indptr)
    values = as_device_array(values)
    return brainpylib.event_sum(events, (indices, indptr), post_num, values)
def test_heter_value(self):
    """Check ``event_sum`` with per-synapse (heterogeneous) values.

    A random fixed-probability connection in CSR form (``post_ids``,
    ``indptr``) is driven by a random boolean spike vector. Each synapse
    carries its own random weight ``values[j]``. The result is compared
    against a NumPy reference. That reference adds ``values[j]`` to
    ``post_ids[j]`` for every synapse ``j`` of every presynaptic neuron
    whose event is ``True``.
    """
    import numpy as np

    bp.math.random.seed(3)
    size = 200
    conn = bp.conn.FixedProb(prob=0.5, seed=3)
    conn(pre_size=size, post_size=size)
    post_ids, indptr = conn.require('pre2post')
    sps = bm.random.random(size).value < 0.5
    values = bm.random.rand(post_ids.size)
    a = event_sum(sps, (post_ids.value, indptr.value), size, values.value)

    # NumPy reference. Row i of the CSR structure owns synapses
    # indptr[i]:indptr[i+1], so repeating each presynaptic event flag
    # np.diff(indptr) times gives one "active" flag per synapse.
    ids = np.asarray(post_ids.value)
    ptr = np.asarray(indptr.value)
    weights = np.asarray(values.value)
    lo, hi = int(ptr[0]), int(ptr[-1])
    syn_active = np.repeat(np.asarray(sps), np.diff(ptr))
    expected = np.zeros(size)
    # np.add.at accumulates correctly when a post index appears repeatedly.
    np.add.at(expected, ids[lo:hi][syn_active], weights[lo:hi][syn_active])
    # The kernel accumulates in lower precision and in a different order,
    # so compare with a tolerance rather than exactly.
    np.testing.assert_allclose(np.asarray(a), expected, rtol=1e-4, atol=1e-4)