コード例 #1
0
def LIF_hard_reset_fptt(x_seq: torch.Tensor, v: torch.Tensor, v_threshold: float, v_reset: float, reciprocal_tau: float):
    r'''
    .. _LIF_hard_reset_fptt-cn:

    .. _LIF_hard_reset_fptt-en:

    The multi-step version of :ref:`LIF_hard_reset_forward <LIF_hard_reset_forward-en>`.

    :param reciprocal_tau: :math:`\frac{1}{\tau}`
    :type reciprocal_tau: float

    See :ref:`hard_reset_fptt_template <hard_reset_fptt_template-en>` for details about the other arguments.

    All arguments are forwarded unchanged, in the same order, to the compiled
    ``_C_neuron.LIF_hard_reset_fptt`` kernel, and its result is returned as-is.
    The structure of that result is defined by the extension binding.
    '''
    kernel = _C_neuron.LIF_hard_reset_fptt
    return kernel(x_seq, v, v_threshold, v_reset, reciprocal_tau)
コード例 #2
0
 def forward(self, dv_seq: torch.Tensor):
     """Run the neuron over a whole input sequence and return its spikes.

     Only the hard-reset kernels are wired up here: a ``None`` ``v_reset``
     raises ``NotImplementedError``.

     On the first call (``self.v`` is not yet a tensor) the membrane potential
     is created with ``torch.zeros_like(dv_seq[0].data)``. It is then filled
     with ``self.v_reset`` when that value is non-zero.

     In training mode the sequence goes through ``LIFMultiStep.apply``.
     Otherwise the compiled ``_C_neuron.LIF_hard_reset_fptt`` kernel is called
     directly. In both cases ``self.v`` is overwritten with the second value
     returned by the call.

     :param dv_seq: input sequence; presumably time-major, since ``dv_seq[0]``
         is used as the template for the initial state — TODO confirm
     :return: the spike sequence (first value returned by the kernel call)
     :raises NotImplementedError: if ``self.v_reset`` is ``None``
     """
     if self.v_reset is None:
         raise NotImplementedError

     if not isinstance(self.v, torch.Tensor):
         # Lazy state init: same shape/dtype/device as one step of the input.
         # ``.data`` keeps the new state out of the autograd graph.
         self.v = torch.zeros_like(dv_seq[0].data)
         if self.v_reset != 0.0:
             self.v.fill_(self.v_reset)

     # NOTE(review): both branches unpack exactly two values; confirm this
     # matches the arity returned by the bound kernel / autograd Function.
     if not self.training:
         # Inference: call the compiled kernel directly, bypassing autograd.
         spike_seq, self.v = _C_neuron.LIF_hard_reset_fptt(
             dv_seq, self.v, self.v_threshold, self.v_reset, self.reciprocal_tau)
         return spike_seq

     # Training: go through the autograd Function so the call is recorded
     # for the backward pass.
     spike_seq, self.v = LIFMultiStep.apply(
         dv_seq, self.v, self.v_threshold, self.v_reset, self.alpha,
         self.detach_reset, self.grad_surrogate_function_index,
         self.reciprocal_tau)
     return spike_seq
コード例 #3
0
 def forward(ctx, x_seq, v, v_threshold, v_reset, alpha, detach_reset,
             grad_surrogate_function_index, tau):
     """Multi-step hard-reset LIF forward pass for an autograd Function.

     Calls ``cext_neuron.LIF_hard_reset_fptt`` once for the whole sequence and
     returns its three outputs. When ``x_seq`` requires grad, ``h_seq`` and
     ``spike_seq`` are saved with ``ctx.save_for_backward``. The remaining
     arguments are also stored as plain ``ctx`` attributes for ``backward``.

     :param ctx: autograd context object
     :param x_seq: input sequence passed to the kernel
     :param v: membrane potential passed to the kernel as the initial state
     :param v_threshold: passed to the kernel; stored on ``ctx``
     :param v_reset: passed to the kernel; ``None`` is not supported
     :param alpha: not used in the forward computation; stored on ``ctx``
     :param detach_reset: not used in the forward computation; stored on ``ctx``
     :param grad_surrogate_function_index: not used in the forward
         computation; stored on ``ctx``
     :param tau: forwarded unchanged as the kernel's last argument.
         NOTE(review): confirm whether the kernel expects tau or 1/tau here.
     :return: ``(h_seq, spike_seq, v_next)`` as produced by the kernel
     :raises NotImplementedError: if ``v_reset`` is ``None``
     """
     if v_reset is None:
         raise NotImplementedError

     h_seq, spike_seq, v_next = cext_neuron.LIF_hard_reset_fptt(
         x_seq, v, v_threshold, v_reset, tau)

     # NOTE(review): only x_seq.requires_grad is checked. If v could require
     # grad while x_seq does not, backward would find nothing saved; confirm
     # that cannot happen.
     if not x_seq.requires_grad:
         return h_seq, spike_seq, v_next

     ctx.save_for_backward(h_seq, spike_seq)
     # save_for_backward is meant for tensors; the remaining (presumably
     # non-tensor) arguments are kept as plain ctx attributes instead.
     ctx.v_threshold, ctx.v_reset = v_threshold, v_reset
     ctx.alpha, ctx.detach_reset = alpha, detach_reset
     ctx.grad_surrogate_function_index = grad_surrogate_function_index
     ctx.tau = tau
     return h_seq, spike_seq, v_next