def syn2post_min(syn_values, post_ids, post_num: int, indices_are_sorted=True):
    """Reduce synaptic values onto post-synaptic neurons by taking the minimum.

    Conceptually this corresponds to:

    .. highlight:: python
    .. code-block:: python

        post_val = np.zeros(post_num)
        for syn_i, post_i in enumerate(post_ids):
            post_val[post_i] = np.minimum(post_val[post_i], syn_values[syn_i])

    NOTE(review): the actual reduction is delegated to ``_jit_seg_min``; the
    starting/fill value it uses for each post-synaptic neuron (the loop above
    shows zeros) is not visible here -- confirm against that helper.

    Parameters
    ----------
    syn_values: jax.numpy.ndarray, JaxArray, Variable
        One value per synapse. Boolean input is cast to ``int32`` first.
    post_ids: jax.numpy.ndarray, JaxArray
        Post-synaptic neuron id of each synapse. Ids produced by
        ``brainpy.conn.TwoEndConnector`` are sorted; for ids from any other
        source, sorting is not guaranteed and you should pass
        ``indices_are_sorted=False``.
    post_num: int
        Number of post-synaptic neurons (length of the output).
    indices_are_sorted: bool
        Whether ``post_ids`` is known to be sorted; forwarded to the
        segment reduction.

    Returns
    -------
    post_val: jax.numpy.ndarray, JaxArray
        The per-neuron minimum, of length ``post_num``.
    """
    targets = as_device_array(post_ids)
    values = as_device_array(syn_values)
    # Booleans are promoted to int32 before the segment reduction.
    values = jnp.asarray(values, dtype=jnp.int32) if values.dtype == jnp.bool_ else values
    return _jit_seg_min(values, targets, post_num, indices_are_sorted)
def syn2post_softmax(syn_values, post_ids, post_num: int, indices_are_sorted=True):
    """Compute a softmax over synaptic values, grouped by post-synaptic neuron.

    Each synapse's value is exponentiated after subtracting the maximum value
    among the synapses that share its post-synaptic neuron. The result is then
    divided by the sum of those exponentials for that neuron. Any NaN/inf in
    the result is passed through ``jnp.nan_to_num``.

    Parameters
    ----------
    syn_values: jax.numpy.ndarray, JaxArray, Variable
        One value per synapse. Boolean input is cast to ``int32`` first.
    post_ids: jax.numpy.ndarray, JaxArray
        Post-synaptic neuron id of each synapse. Ids produced by
        ``brainpy.conn.TwoEndConnector`` are sorted; for ids from any other
        source, sorting is not guaranteed and you should pass
        ``indices_are_sorted=False``.
    post_num: int
        Number of post-synaptic neurons.
    indices_are_sorted: bool
        Whether ``post_ids`` is known to be sorted; forwarded to the
        segment reductions.

    Returns
    -------
    post_val: jax.numpy.ndarray, JaxArray
        One softmax weight per synapse (same length as ``syn_values``).
    """
    post_ids = as_device_array(post_ids)
    syn_values = as_device_array(syn_values)
    if syn_values.dtype == jnp.bool_:
        syn_values = jnp.asarray(syn_values, dtype=jnp.int32)

    # Shift by the per-neuron maximum so that exp() stays bounded.
    seg_max = _jit_seg_max(syn_values, post_ids, post_num, indices_are_sorted)
    exp_vals = jnp.exp(syn_values - seg_max[post_ids])

    seg_total = _jit_seg_sum(exp_vals, post_ids, post_num, indices_are_sorted)
    return jnp.nan_to_num(exp_vals / seg_total[post_ids])
def syn2post_sum(syn_values, post_ids, post_num: int, indices_are_sorted=True):
    """Accumulate synaptic values onto their post-synaptic neurons.

    Equivalent to the following NumPy loop:

    .. highlight:: python
    .. code-block:: python

        post_val = np.zeros(post_num)
        for syn_i, post_i in enumerate(post_ids):
            post_val[post_i] += syn_values[syn_i]

    Parameters
    ----------
    syn_values: jax.numpy.ndarray, JaxArray, Variable
        One value per synapse. Boolean input is cast to ``int32`` first.
    post_ids: jax.numpy.ndarray, JaxArray
        Post-synaptic neuron id of each synapse.
    post_num: int
        Number of post-synaptic neurons (length of the output).
    indices_are_sorted: bool
        Whether ``post_ids`` is known to be sorted; forwarded to the
        segment reduction.

    Returns
    -------
    post_val: jax.numpy.ndarray, JaxArray
        The per-neuron sum, of length ``post_num``.
    """
    targets = as_device_array(post_ids)
    values = as_device_array(syn_values)
    is_bool = values.dtype == jnp.bool_
    if is_bool:
        values = jnp.asarray(values, dtype=jnp.int32)
    return _jit_seg_sum(values, targets, post_num, indices_are_sorted)
def pre2syn(pre_values, pre_ids):
    """The pre-to-syn computation.

    Expand pre-synaptic data to the dimension of the synapses.

    For array-valued ``pre_values`` this function is equivalent to:

    .. highlight:: python
    .. code-block:: python

        syn_val = np.zeros(len(pre_ids))
        for syn_i, pre_i in enumerate(pre_ids):
            syn_val[syn_i] = pre_values[pre_i]

    A scalar ``pre_values`` is broadcast to every synapse, keeping its dtype.

    Parameters
    ----------
    pre_values: float, jax.numpy.ndarray, JaxArray, Variable
        The pre-synaptic value(s).
    pre_ids: jax.numpy.ndarray, JaxArray
        The pre-synaptic neuron index of each synapse.

    Returns
    -------
    syn_val: jax.numpy.ndarray, JaxArray
        The synaptic values, of length ``len(pre_ids)``.
    """
    pre_values = as_device_array(pre_values)
    pre_ids = as_device_array(pre_ids)
    if jnp.ndim(pre_values) == 0:
        # Scalar input: every synapse receives the same value. ``jnp.full``
        # gives the same values and dtype as ``ones(...) * value`` without
        # the extra multiply.
        return jnp.full(len(pre_ids), pre_values, dtype=pre_values.dtype)
    return _pre2syn(pre_ids, pre_values)
def pre2post_mean(pre_values, post_num, post_ids, pre_ids=None):
    """The pre-to-post synaptic mean computation.

    * Scalar ``pre_values``: every post-synaptic neuron listed in ``post_ids``
      receives ``pre_values`` (the mean of identical values). All other
      neurons stay at 0.
    * Array ``pre_values``: values are first gathered per synapse through
      ``pre_ids`` (see ``pre2syn``). They are then averaged per post-synaptic
      neuron with ``syn2post_mean``.

    Parameters
    ----------
    pre_values: float, jax.numpy.ndarray, JaxArray, Variable
        The pre-synaptic values.
    post_num: int
        Output dimension. The number of post-synaptic neurons.
    post_ids: jax.numpy.ndarray, JaxArray
        The connected post-synaptic neuron ids.
    pre_ids: jax.numpy.ndarray, JaxArray, optional
        The connected pre-synaptic neuron ids. Required when ``pre_values``
        is not a scalar (checked by ``_raise_pre_ids_is_none``).

    Returns
    -------
    post_val: jax.numpy.ndarray, JaxArray
        The value with the size of post-synaptic neurons.
    """
    pre_values = as_device_array(pre_values)
    post_ids = as_device_array(post_ids)
    if jnp.ndim(pre_values) == 0:
        out = jnp.zeros(post_num, dtype=profile.float_)
        # Duplicate ids in ``post_ids`` all write the same scalar, so no
        # de-duplication is needed. ``jnp.unique`` (used previously) cannot be
        # traced under ``jit`` without a static ``size`` and adds a sort.
        return out.at[post_ids].set(pre_values)
    _raise_pre_ids_is_none(pre_ids)
    pre_ids = as_device_array(pre_ids)
    syn_values = pre2syn(pre_values, pre_ids)
    return syn2post_mean(syn_values, post_ids, post_num)
def pre2post_event_sum(events, pre2post, post_num, values=1.):
    """The pre-to-post synaptic computation with event-driven summation.

    When ``values`` is a scalar, this function is equivalent to

    .. highlight:: python
    .. code-block:: python

        post_val = np.zeros(post_num)
        post_ids, idnptr = pre2post
        for i in range(len(events)):
            if events[i]:
                for j in range(idnptr[i], idnptr[i + 1]):
                    post_val[post_ids[j]] += values

    When ``values`` is a vector (with the length of ``len(post_ids)``),
    this function is equivalent to

    .. highlight:: python
    .. code-block:: python

        post_val = np.zeros(post_num)
        post_ids, idnptr = pre2post
        for i in range(len(events)):
            if events[i]:
                for j in range(idnptr[i], idnptr[i + 1]):
                    post_val[post_ids[j]] += values[j]

    Parameters
    ----------
    events: JaxArray, jax.numpy.ndarray, Variable
        The events, must be bool.
    pre2post: tuple of JaxArray, tuple of jax.numpy.ndarray
        A tuple ``(post_ids, idnptr)`` with the pre-to-post connection
        information (``idnptr[i]:idnptr[i+1]`` indexes the targets of
        pre-synaptic neuron ``i``).
    post_num: int
        The number of post-synaptic neurons.
    values: float, JaxArray, jax.numpy.ndarray
        The value(s) to sum.

    Returns
    -------
    out: JaxArray, jax.numpy.ndarray
        A tensor with the shape of ``post_num``.
    """
    _check_brainpylib(pre2post_event_sum.__name__)
    indices, idnptr = pre2post
    events = as_device_array(events)
    indices = as_device_array(indices)
    idnptr = as_device_array(idnptr)
    values = as_device_array(values)
    return brainpylib.event_sum(events, (indices, idnptr), post_num, values)
def pre2post_max(pre_values, post_num, post_ids, pre_ids=None):
    """The pre-to-post synaptic maximization.

    For array-valued ``pre_values`` this function is equivalent to:

    .. highlight:: python
    .. code-block:: python

        post_val = np.zeros(post_num)
        for pre_i, post_i in zip(pre_ids, post_ids):
            post_val[post_i] = np.maximum(post_val[post_i], pre_values[pre_i])

    A scalar ``pre_values`` is used for every connection, and ``pre_ids`` is
    then not needed. Because the output starts from zeros, neurons without a
    connection stay at 0, and results never go below 0.

    Parameters
    ----------
    pre_values: float, jax.numpy.ndarray, JaxArray, Variable
        The pre-synaptic values.
    post_num: int
        Output dimension. The number of post-synaptic neurons.
    post_ids: jax.numpy.ndarray, JaxArray
        The connected post-synaptic neuron ids.
    pre_ids: jax.numpy.ndarray, JaxArray, optional
        The connected pre-synaptic neuron ids. Required when ``pre_values``
        is not a scalar (checked by ``_raise_pre_ids_is_none``).

    Returns
    -------
    post_val: jax.numpy.ndarray, JaxArray
        The value with the size of post-synaptic neurons.
    """
    out = jnp.zeros(post_num, dtype=profile.float_)
    pre_values = as_device_array(pre_values)
    post_ids = as_device_array(post_ids)
    if jnp.ndim(pre_values) != 0:
        _raise_pre_ids_is_none(pre_ids)
        pre_ids = as_device_array(pre_ids)
        # Gather one value per connection so it lines up with ``post_ids``.
        pre_values = pre_values[pre_ids]
    return out.at[post_ids].max(pre_values)
def rfftn(a, s=None, axes=None, norm=None):
    """N-dimensional FFT of real input.

    Thin wrapper over ``jax.numpy.fft.rfftn``: ``a`` is unwrapped with
    ``as_device_array`` and the result is re-wrapped as a ``JaxArray``.
    ``s``, ``axes`` and ``norm`` are passed through unchanged.
    """
    device_input = as_device_array(a)
    spectrum = jax.numpy.fft.rfftn(a=device_input, s=s, axes=axes, norm=norm)
    return JaxArray(spectrum)
def rfft(a, n=None, axis=-1, norm=None):
    """One-dimensional FFT of real input.

    Thin wrapper over ``jax.numpy.fft.rfft``: ``a`` is unwrapped with
    ``as_device_array`` and the result is re-wrapped as a ``JaxArray``.
    ``n``, ``axis`` and ``norm`` are passed through unchanged.
    """
    device_input = as_device_array(a)
    spectrum = jax.numpy.fft.rfft(a=device_input, n=n, axis=axis, norm=norm)
    return JaxArray(spectrum)
def irfft2(a, s=None, axes=(-2, -1), norm=None):
    """Inverse two-dimensional FFT producing real output.

    Thin wrapper over ``jax.numpy.fft.irfft2``: ``a`` is unwrapped with
    ``as_device_array`` and the result is re-wrapped as a ``JaxArray``.
    ``s``, ``axes`` (default: the last two axes) and ``norm`` are passed
    through unchanged.
    """
    device_input = as_device_array(a)
    signal = jax.numpy.fft.irfft2(a=device_input, s=s, axes=axes, norm=norm)
    return JaxArray(signal)
def _cond_fun(op):
    """Restore dynamic variable values, then evaluate the loop condition.

    ``op`` is unpacked as ``(dyn_values, static_values)``. Each entry of
    ``dyn_values`` is written into the ``.value`` of the matching variable in
    ``dyn_vars`` (pairing stops at the shorter of the two). Then
    ``cond_fun(static_values)`` is called and its result is converted with
    ``as_device_array``.

    NOTE(review): ``dyn_vars`` and ``cond_fun`` are free names taken from an
    enclosing scope that is not visible here; this looks like the condition
    callback of a while-loop wrapper -- confirm against the enclosing function.
    """
    new_values, static_values = op
    for var, val in zip(dyn_vars, new_values):
        var.value = val
    predicate = cond_fun(static_values)
    return as_device_array(predicate)