Example #1
def syn2post_min(syn_values, post_ids, post_num: int, indices_are_sorted=True):
    """The syn-to-post minimization computation.

  This function is equivalent to:

  .. highlight:: python
  .. code-block:: python

    post_val = np.full(post_num, np.inf)
    for syn_i, post_i in enumerate(post_ids):
      post_val[post_i] = np.minimum(post_val[post_i], syn_values[syn_i])

  Parameters
  ----------
  syn_values: jax.numpy.ndarray, JaxArray, Variable
    The synaptic values.
  post_ids: jax.numpy.ndarray, JaxArray
    The post-synaptic neuron ids. If ``post_ids`` is generated by
    ``brainpy.conn.TwoEndConnector``, then it has sorted indices.
    Otherwise, this function cannot guarantee that the indices are sorted,
    and you should set ``indices_are_sorted=False``.
  post_num: int
    The number of the post-synaptic neurons.
  indices_are_sorted: bool
    Whether ``post_ids`` is known to be sorted.

  Returns
  -------
  post_val: jax.numpy.ndarray, JaxArray
    The post-synaptic value.
  """
    post_ids = as_device_array(post_ids)
    syn_values = as_device_array(syn_values)
    if syn_values.dtype == jnp.bool_:
        syn_values = jnp.asarray(syn_values, dtype=jnp.int32)
    return _jit_seg_min(syn_values, post_ids, post_num, indices_are_sorted)
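
A usage sketch (an assumption, not from the library): if ``_jit_seg_min`` is a jitted wrapper around ``jax.ops.segment_min``, the reduction above can be reproduced directly with the JAX primitive on toy data.

import jax.numpy as jnp
from jax.ops import segment_min

syn_values = jnp.array([3., 1., 4., 2.])   # one value per synapse
post_ids = jnp.array([0, 0, 1, 2])         # target post-synaptic neuron of each synapse
post_val = segment_min(syn_values, post_ids, num_segments=3, indices_are_sorted=True)
# post_val -> [1., 4., 2.]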
Example #2
def syn2post_softmax(syn_values,
                     post_ids,
                     post_num: int,
                     indices_are_sorted=True):
    """The syn-to-post softmax computation.

  Parameters
  ----------
  syn_values: jax.numpy.ndarray, JaxArray, Variable
    The synaptic values.
  post_ids: jax.numpy.ndarray, JaxArray
    The post-synaptic neuron ids. If ``post_ids`` is generated by
    ``brainpy.conn.TwoEndConnector``, then it has sorted indices.
    Otherwise, this function cannot guarantee that the indices are sorted,
    and you should set ``indices_are_sorted=False``.
  post_num: int
    The number of the post-synaptic neurons.
  indices_are_sorted: bool
    Whether ``post_ids`` is known to be sorted.

  Returns
  -------
  post_val: jax.numpy.ndarray, JaxArray
    The post-synaptic value.
  """
    post_ids = as_device_array(post_ids)
    syn_values = as_device_array(syn_values)
    if syn_values.dtype == jnp.bool_:
        syn_values = jnp.asarray(syn_values, dtype=jnp.int32)
    syn_maxs = _jit_seg_max(syn_values, post_ids, post_num, indices_are_sorted)
    syn_values = syn_values - syn_maxs[post_ids]
    syn_values = jnp.exp(syn_values)
    normalizers = _jit_seg_sum(syn_values, post_ids, post_num,
                               indices_are_sorted)
    softmax = syn_values / normalizers[post_ids]
    return jnp.nan_to_num(softmax)
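
The body above is a numerically stable per-post softmax. A standalone sketch of the same computation using ``jax.ops.segment_max`` and ``segment_sum`` directly (illustrative values only; the segment helpers above are assumed to mirror these primitives):

import jax.numpy as jnp
from jax.ops import segment_max, segment_sum

syn_values = jnp.array([1., 2., 3., 4.])
post_ids = jnp.array([0, 0, 1, 1])
post_num = 2

seg_max = segment_max(syn_values, post_ids, num_segments=post_num)
shifted = jnp.exp(syn_values - seg_max[post_ids])   # subtract each segment's max for stability
seg_sum = segment_sum(shifted, post_ids, num_segments=post_num)
softmax = shifted / seg_sum[post_ids]
# softmax -> [0.269, 0.731, 0.269, 0.731] (approximately)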
Example #3
def syn2post_sum(syn_values, post_ids, post_num: int, indices_are_sorted=True):
    """The syn-to-post summation computation.

  This function is equivalent to:

  .. highlight:: python
  .. code-block:: python

    post_val = np.zeros(post_num)
    for syn_i, post_i in enumerate(post_ids):
      post_val[post_i] += syn_values[syn_i]

  Parameters
  ----------
  syn_values: jax.numpy.ndarray, JaxArray, Variable
    The synaptic values.
  post_ids: jax.numpy.ndarray, JaxArray
    The post-synaptic neuron ids.
  post_num: int
    The number of the post-synaptic neurons.
  indices_are_sorted: bool
    Whether ``post_ids`` is known to be sorted.

  Returns
  -------
  post_val: jax.numpy.ndarray, JaxArray
    The post-synaptic value.
  """
    post_ids = as_device_array(post_ids)
    syn_values = as_device_array(syn_values)
    if syn_values.dtype == jnp.bool_:
        syn_values = jnp.asarray(syn_values, dtype=jnp.int32)
    return _jit_seg_sum(syn_values, post_ids, post_num, indices_are_sorted)
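
A usage sketch with ``jax.ops.segment_sum``, which ``_jit_seg_sum`` presumably wraps; the arrays are illustrative only.

import jax.numpy as jnp
from jax.ops import segment_sum

syn_values = jnp.array([0.1, 0.2, 0.3, 0.4])   # one value per synapse
post_ids = jnp.array([0, 0, 1, 2])             # target post-synaptic neuron of each synapse
post_val = segment_sum(syn_values, post_ids, num_segments=3, indices_are_sorted=True)
# post_val -> [0.3, 0.3, 0.4]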
Example #4
def pre2syn(pre_values, pre_ids):
    """The pre-to-syn computation.

  Expand the pre-synaptic data to the dimension of synapses.

  This function is equivalent to:

  .. highlight:: python
  .. code-block:: python

    syn_val = np.zeros(len(pre_ids))
    for syn_i, pre_i in enumerate(pre_ids):
      syn_val[syn_i] = pre_values[pre_i]

  Parameters
  ----------
  pre_values: float, jax.numpy.ndarray, JaxArray, Variable
    The pre-synaptic value.
  pre_ids: jax.numpy.ndarray, JaxArray
    The pre-synaptic neuron index.

  Returns
  -------
  syn_val: jax.numpy.ndarray, JaxArray
    The synaptic value.
  """
    pre_values = as_device_array(pre_values)
    pre_ids = as_device_array(pre_ids)
    if jnp.ndim(pre_values) == 0:
        return jnp.ones(len(pre_ids), dtype=pre_values.dtype) * pre_values
    else:
        return _pre2syn(pre_ids, pre_values)
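
In the non-scalar branch, ``_pre2syn`` is presumably a simple gather; a minimal sketch of the same pre-to-synapse expansion with plain ``jax.numpy``:

import jax.numpy as jnp

pre_values = jnp.array([0.5, 1.0, 1.5])   # one value per pre-synaptic neuron
pre_ids = jnp.array([0, 0, 2, 1])         # pre-synaptic neuron of each synapse
syn_val = pre_values[pre_ids]             # gather to the synapse dimension
# syn_val -> [0.5, 0.5, 1.5, 1.0]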
Example #5
def pre2post_mean(pre_values, post_num, post_ids, pre_ids=None):
    """The pre-to-post synaptic mean computation.

  Parameters
  ----------
  pre_values: float, jax.numpy.ndarray, JaxArray, Variable
    The pre-synaptic values.
  post_num: int
    Output dimension. The number of post-synaptic neurons.
  post_ids: jax.numpy.ndarray, JaxArray
    The connected post-synaptic neuron ids.
  pre_ids: jax.numpy.ndarray, JaxArray, optional
    The connected pre-synaptic neuron ids. Required when ``pre_values``
    is not a scalar.

  Returns
  -------
  post_val: jax.numpy.ndarray, JaxArray
    The value with the size of post-synaptic neurons.
  """
    out = jnp.zeros(post_num, dtype=profile.float_)
    pre_values = as_device_array(pre_values)
    post_ids = as_device_array(post_ids)
    if jnp.ndim(pre_values) == 0:
        # a scalar value is broadcast to every connected post-synaptic neuron
        return out.at[jnp.unique(post_ids)].set(pre_values)
    else:
        _raise_pre_ids_is_none(pre_ids)
        pre_ids = as_device_array(pre_ids)
        pre_values = pre2syn(pre_values, pre_ids)
        return syn2post_mean(pre_values, post_ids, post_num)
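
For the non-scalar branch, a standalone sketch of the equivalent computation (gather to synapses, then a per-post mean built from two segment sums; names and values are illustrative only):

import jax.numpy as jnp
from jax.ops import segment_sum

pre_values = jnp.array([1., 2., 3.])
pre_ids = jnp.array([0, 1, 1, 2])
post_ids = jnp.array([0, 0, 1, 1])
post_num = 2

syn_values = pre_values[pre_ids]                      # pre -> synapse
sums = segment_sum(syn_values, post_ids, num_segments=post_num)
counts = segment_sum(jnp.ones_like(syn_values), post_ids, num_segments=post_num)
post_val = sums / counts                              # per-post mean
# post_val -> [1.5, 2.5]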
Example #6
def pre2post_event_sum(events, pre2post, post_num, values=1.):
    """The pre-to-post synaptic computation with event-driven summation.

  When ``values`` is a scalar, this function is equivalent to

  .. highlight:: python
  .. code-block:: python

    post_val = np.zeros(post_num)
    post_ids, idnptr = pre2post
    for i in range(pre_num):
      if events[i]:
        for j in range(idnptr[i], idnptr[i+1]):
          post_val[post_ids[j]] += values

  When ``values`` is a vector (of length ``len(post_ids)``),
  this function is equivalent to

  .. highlight:: python
  .. code-block:: python

    post_val = np.zeros(post_num)

    post_ids, idnptr = pre2post
    for i in range(pre_num):
      if events[i]:
        for j in range(idnptr[i], idnptr[i+1]):
          post_val[post_ids[j]] += values[j]


  Parameters
  ----------
  events: JaxArray, jax.numpy.ndarray, Variable
    The events; must be a bool array.
  pre2post: tuple of JaxArray, tuple of jax.numpy.ndarray
    A tuple ``(post_ids, idnptr)`` containing the pre-to-post connection
    information.
  post_num: int
    The number of post-synaptic neurons.
  values: float, JaxArray, jax.numpy.ndarray
    The value(s) to sum.

  Returns
  -------
  out: JaxArray, jax.numpy.ndarray
    A tensor with the shape of ``post_num``.
  """
    _check_brainpylib(pre2post_event_sum.__name__)
    indices, idnptr = pre2post
    events = as_device_array(events)
    indices = as_device_array(indices)
    idnptr = as_device_array(idnptr)
    values = as_device_array(values)
    return brainpylib.event_sum(events, (indices, idnptr), post_num, values)
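
A plain NumPy reference for the scalar-``values`` case, matching the loop in the docstring (toy CSR connectivity; this is a readability sketch, not the brainpylib kernel):

import numpy as np

post_ids = np.array([0, 1, 2])        # CSR indices: pre 0 -> posts [0, 1], pre 1 -> post [2]
idnptr = np.array([0, 2, 3])          # CSR index pointer
events = np.array([True, False])
post_num = 3
values = 1.

post_val = np.zeros(post_num)
for i in range(len(events)):
    if events[i]:
        for j in range(idnptr[i], idnptr[i + 1]):
            post_val[post_ids[j]] += values
# post_val -> [1., 1., 0.]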
Example #7
def pre2post_max(pre_values, post_num, post_ids, pre_ids=None):
    """The pre-to-post synaptic maximization.

  This function is equivalent to:

  .. highlight:: python
  .. code-block:: python

     post_val = np.zeros(post_num)
     for i, j in zip(pre_ids, post_ids):
       post_val[j] = np.maximum(post_val[j], pre_values[i])

  Parameters
  ----------
  pre_values: float, jax.numpy.ndarray, JaxArray, Variable
    The pre-synaptic values.
  post_num: int
    Output dimension. The number of post-synaptic neurons.
  post_ids: jax.numpy.ndarray, JaxArray
    The connected post-synaptic neuron ids.
  pre_ids: jax.numpy.ndarray, JaxArray, optional
    The connected pre-synaptic neuron ids. Required when ``pre_values``
    is not a scalar.

  Returns
  -------
  post_val: jax.numpy.ndarray, JaxArray
    The value with the size of post-synaptic neurons.
  """
    out = jnp.zeros(post_num, dtype=profile.float_)
    pre_values = as_device_array(pre_values)
    post_ids = as_device_array(post_ids)
    if jnp.ndim(pre_values) != 0:
        _raise_pre_ids_is_none(pre_ids)
        pre_ids = as_device_array(pre_ids)
        pre_values = pre_values[pre_ids]
    return out.at[post_ids].max(pre_values)
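
A usage sketch of the same scatter-max with plain ``jax.numpy`` (illustrative values only; ``profile.float_`` is simply replaced by the default float dtype here):

import jax.numpy as jnp

pre_values = jnp.array([1., 5., 3.])
pre_ids = jnp.array([0, 1, 2, 1])
post_ids = jnp.array([0, 0, 1, 1])

post_val = jnp.zeros(2).at[post_ids].max(pre_values[pre_ids])
# post_val -> [5., 5.]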
Example #8
def rfftn(a, s=None, axes=None, norm=None):
    """Compute the N-dimensional FFT of a real array (wraps ``jax.numpy.fft.rfftn``)."""
    a = as_device_array(a)
    return JaxArray(jax.numpy.fft.rfftn(a=a, s=s, axes=axes, norm=norm))
Example #9
def rfft(a, n=None, axis=-1, norm=None):
    """Compute the one-dimensional FFT of a real array (wraps ``jax.numpy.fft.rfft``)."""
    a = as_device_array(a)
    return JaxArray(jax.numpy.fft.rfft(a=a, n=n, axis=axis, norm=norm))
Example #10
def irfft2(a, s=None, axes=(-2, -1), norm=None):
    """Compute the inverse two-dimensional FFT of real input (wraps ``jax.numpy.fft.irfft2``)."""
    a = as_device_array(a)
    return JaxArray(jax.numpy.fft.irfft2(a=a, s=s, axes=axes, norm=norm))
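
These wrappers only forward to ``jax.numpy.fft`` and re-wrap the result in a ``JaxArray``. A quick round-trip sketch against the underlying JAX functions (illustrative only; it uses ``rfft2``, which is not among the wrappers above):

import jax.numpy as jnp

x = jnp.arange(8.0).reshape(2, 4)
spec = jnp.fft.rfft(x, axis=-1)                     # 1-D real FFT along the last axis
back = jnp.fft.irfft2(jnp.fft.rfft2(x), s=x.shape)  # 2-D round trip recovers x (up to float error)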
Example #11
def _cond_fun(op):
    dyn_values, static_values = op
    # write the current dynamic values back into the variables in ``dyn_vars``
    for v, d in zip(dyn_vars, dyn_values):
        v.value = d
    # evaluate ``cond_fun`` on the static values and return a device array
    return as_device_array(cond_fun(static_values))
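
Here ``dyn_vars`` and ``cond_fun`` are closure variables, and the helper is presumably passed to ``jax.lax.while_loop``. A self-contained toy illustration of that pattern (the names ``cond`` and ``body`` are hypothetical, not the library's):

import jax
import jax.numpy as jnp

def cond(op):
    dyn, static = op
    return static < 5                  # keep looping while the counter is below 5

def body(op):
    dyn, static = op
    return dyn + 1.0, static + 1       # update the dynamic value and the counter

final = jax.lax.while_loop(cond, body, (jnp.array(0.0), 0))
# final -> (5.0, 5)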