import torch

# `accelerating` provides the spike-aware elementwise ops used below
# (in SpikingJelly this is spikingjelly.clock_driven.accelerating).
from spikingjelly.clock_driven import accelerating


def forward(self, x: torch.Tensor, hc=None):
    if hc is None:
        h = torch.zeros(size=[x.shape[0], self.hidden_size], dtype=torch.float, device=x.device)
    else:
        # the GRU has no cell state, so only hc[0] (the hidden state) is used
        h = hc[0]
    y_ih = torch.split(self.linear_ih(x), self.hidden_size, dim=1)
    y_hh = torch.split(self.linear_hh(h), self.hidden_size, dim=1)
    r = self.surrogate_function1(y_ih[0] + y_hh[0])
    z = self.surrogate_function1(y_ih[1] + y_hh[1])
    if self.surrogate_function2 is None:
        n = self.surrogate_function1(y_ih[2] + r * y_hh[2])
    else:
        assert self.surrogate_function1.spiking == self.surrogate_function2.spiking
        n = self.surrogate_function2(y_ih[2] + r * y_hh[2])

    if self.surrogate_function1.spiking:
        # the gates are spikes, so the spike-optimized ops can be used
        h = accelerating.mul(accelerating.sub(torch.ones_like(z.data), z), n, True) + accelerating.mul(h, z)
        # h is not necessarily spiking data, so accelerating.mul(h, True) is not used
    else:
        h = (1 - z) * n + z * h
    return h
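# Illustrative usage sketch, not part of the library: drive the spiking GRU
# cell above over T time steps. The constructor SpikingGRUCell(input_size,
# hidden_size) is an assumption based on the attributes used in forward()
# (linear_ih, linear_hh, hidden_size, surrogate_function1).
def _demo_spiking_gru_cell():
    T, batch_size, input_size, hidden_size = 8, 4, 16, 32
    cell = SpikingGRUCell(input_size, hidden_size)  # hypothetical constructor
    x_seq = torch.rand([T, batch_size, input_size])
    hc = None
    for t in range(T):
        h = cell(x_seq[t], hc)  # h.shape = [batch_size, hidden_size]
        hc = (h,)  # the GRU carries only a hidden state
    return h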
def drop(self, batch_size: int):
    # sample an independent dropout mask for every example in the batch;
    # mask_w.shape = [batch_size, out_features, in_features]
    mask_w = (torch.rand_like(self.weight.unsqueeze(0).repeat([batch_size] + [1] * self.weight.dim())) > self.p)
    # self.dropped_w = mask_w.to(self.weight) * self.weight
    self.dropped_w = accelerating.mul(self.weight.unsqueeze(0).repeat(batch_size, 1, 1), mask_w)

    if self.bias is not None:
        # mask_b.shape = [batch_size, out_features]
        mask_b = (torch.rand_like(self.bias.unsqueeze(0).repeat([batch_size] + [1] * self.bias.dim())) > self.p)
        # self.dropped_b = mask_b.to(self.bias) * self.bias
        self.dropped_b = accelerating.mul(self.bias.unsqueeze(0).repeat(batch_size, 1), mask_b)
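# Illustrative usage sketch, not the library's forward: once drop(batch_size)
# has sampled per-sample masks, each example can be multiplied by its own
# dropped weight matrix with a single batched matmul.
def _demo_batched_dropped_linear(layer, x: torch.Tensor):
    # x.shape = [batch_size, in_features]
    layer.drop(x.shape[0])
    # [batch_size, out_features, in_features] @ [batch_size, in_features, 1]
    y = torch.bmm(layer.dropped_w, x.unsqueeze(2)).squeeze(2)
    if layer.bias is not None:
        y = y + layer.dropped_b
    return y  # y.shape = [batch_size, out_features]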
def forward(self, x: torch.Tensor):
    if self.training:
        if self.mask is None:
            # the mask is sampled once and reused at every time step until it is reset
            self.create_mask(x)
        if self.dropout_spikes:
            # x is spiking data, so the spike-optimized multiplication can be used
            return accelerating.mul(self.mask, x)
        else:
            return x * self.mask
    else:
        return x
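# Illustrative usage sketch: unlike torch.nn.Dropout, this layer keeps one mask
# for the whole simulation and reuses it at every time step; clearing the mask
# between sequences makes create_mask() sample a fresh one. The constructor
# arguments (p, dropout_spikes) are assumptions.
def _demo_spiking_dropout(spike_seq: torch.Tensor):
    dp = Dropout(p=0.5, dropout_spikes=True)  # hypothetical constructor
    dp.train()
    out = []
    for t in range(spike_seq.shape[0]):
        out.append(dp(spike_seq[t]))  # the same mask is applied at every t
    dp.mask = None  # next sequence: forward() will create a new mask
    return torch.stack(out)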
def spike_mse_loss(x: torch.Tensor, spikes: torch.Tensor):
    '''
    :param x: an arbitrary tensor
    :param spikes: a spiking tensor. The elements of ``spikes`` must be ``0`` and ``1``, or ``False`` and
        ``True``, and ``spikes.shape`` must be the same as ``x.shape``
    :return: the mean squared error (squared L2 norm) between each element of ``x`` and ``spikes``

    This function is faster than ``torch.nn.functional.mse_loss()`` because it exploits the binary nature
    of spiking data: since :math:`s \\in \\{0, 1\\}` implies :math:`s^{2} = s`, the computation simplifies to

    .. math::
        (x - s)^{2} = x^{2} + s^{2} - 2xs = x^{2} + (1 - 2x)s

    .. note::

        The computation graph of this function differs from that of the standard MSE, so the gradient
        :math:`\\frac{\\partial L}{\\partial s}` computed by this function differs from the one computed
        by ``torch.nn.functional.mse_loss()``.
    '''
    return (x.square() + accelerating.mul(1 - 2 * x, spikes)).mean()
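# Sanity check (illustrative): for binary targets the forward value of
# spike_mse_loss matches torch.nn.functional.mse_loss; only the backward
# graphs differ.
def _check_spike_mse_loss():
    import torch.nn.functional as F
    x = torch.rand([4, 8])
    spikes = (torch.rand([4, 8]) > 0.5).float()
    assert torch.allclose(spike_mse_loss(x, spikes), F.mse_loss(x, spikes), atol=1e-6)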
def forward(self, x: torch.Tensor, hc=None):
    '''
    :param x: the input tensor with ``shape = [batch_size, input_size]``
    :type x: torch.Tensor
    :param hc: (h_0, c_0)

        h_0 : torch.Tensor
            ``shape = [batch_size, hidden_size]``, tensor containing the initial hidden state for each
            element in the batch

        c_0 : torch.Tensor
            ``shape = [batch_size, hidden_size]``, tensor containing the initial cell state for each
            element in the batch

        If (h_0, c_0) is not provided, both ``h_0`` and ``c_0`` default to zero
    :type hc: tuple or None
    :return: (h_1, c_1)

        h_1 : torch.Tensor
            ``shape = [batch_size, hidden_size]``, tensor containing the next hidden state for each
            element in the batch

        c_1 : torch.Tensor
            ``shape = [batch_size, hidden_size]``, tensor containing the next cell state for each
            element in the batch
    :rtype: tuple
    '''
    if hc is None:
        h = torch.zeros(size=[x.shape[0], self.hidden_size], dtype=torch.float, device=x.device)
        c = torch.zeros_like(h)
    else:
        h = hc[0]
        c = hc[1]

    if self.surrogate_function2 is None:
        i, f, g, o = torch.split(self.surrogate_function1(self.linear_ih(x) + self.linear_hh(h)),
                                 self.hidden_size, dim=1)
    else:
        i, f, g, o = torch.split(self.linear_ih(x) + self.linear_hh(h), self.hidden_size, dim=1)
        i = self.surrogate_function1(i)
        f = self.surrogate_function1(f)
        g = self.surrogate_function2(g)
        o = self.surrogate_function1(o)

    if self.surrogate_function2 is not None:
        assert self.surrogate_function1.spiking == self.surrogate_function2.spiking

    if self.surrogate_function1.spiking:
        # the gates are spikes, so the spike-optimized ops can be used
        # c = f * c + i * g
        c = accelerating.mul(c, f) + accelerating.mul(i, g, True)
        # h = o * c
        h = accelerating.mul(c, o)
    else:
        c = c * f + i * g
        h = c * o
    return h, c
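# Illustrative usage sketch: step the spiking LSTM cell over a sequence. The
# constructor SpikingLSTMCell(input_size, hidden_size) follows SpikingJelly's
# clock_driven.rnn module, but its exact signature is an assumption here.
def _demo_spiking_lstm_cell():
    T, batch_size, input_size, hidden_size = 8, 4, 16, 32
    cell = SpikingLSTMCell(input_size, hidden_size)
    x_seq = torch.rand([T, batch_size, input_size])
    hc = None
    for t in range(T):
        h, c = cell(x_seq[t], hc)
        hc = (h, c)
    return h, c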