Example #1
    def forward(self, x: torch.Tensor, hc=None):
        if hc is None:
            h = torch.zeros(size=[x.shape[0], self.hidden_size],
                            dtype=torch.float,
                            device=x.device)
            c = torch.zeros_like(h)  # c is assigned but never used below (a GRU has no cell state); it only mirrors the (h, c) interface
        else:
            h = hc[0]
            c = hc[1]

        # linear_ih and linear_hh each produce the pre-activations of all three
        # gates at once; split them into hidden_size chunks for r, z and n
        y_ih = torch.split(self.linear_ih(x), self.hidden_size, dim=1)
        y_hh = torch.split(self.linear_hh(h), self.hidden_size, dim=1)
        r = self.surrogate_function1(y_ih[0] + y_hh[0])
        z = self.surrogate_function1(y_ih[1] + y_hh[1])

        if self.surrogate_function2 is None:
            n = self.surrogate_function1(y_ih[2] + r * y_hh[2])
        else:
            assert self.surrogate_function1.spiking == self.surrogate_function2.spiking
            n = self.surrogate_function2(y_ih[2] + r * y_hh[2])

        if self.surrogate_function1.spiking:
            # spike-based acceleration can be used here
            h = accelerating.mul(accelerating.sub(torch.ones_like(z.data), z),
                                 n, True) + accelerating.mul(h, z)
            # h is not necessarily spiking data, so accelerating.mul(h, True) is not used
        else:
            h = (1 - z) * n + z * h
        return h
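
For reference, the non-accelerated branch makes the underlying GRU update explicit: h = (1 - z) * n + z * h with spiking gates. Below is a minimal self-contained sketch of the same computation, with a hard threshold standing in for the surrogate functions; all names and sizes are illustrative, not from the source:

import torch
import torch.nn as nn

input_size, hidden_size, batch_size = 16, 32, 4
linear_ih = nn.Linear(input_size, 3 * hidden_size)
linear_hh = nn.Linear(hidden_size, 3 * hidden_size)

x = torch.rand(batch_size, input_size)
h = torch.zeros(batch_size, hidden_size)

y_ih = torch.split(linear_ih(x), hidden_size, dim=1)
y_hh = torch.split(linear_hh(h), hidden_size, dim=1)
r = (y_ih[0] + y_hh[0] >= 0).float()       # reset gate as a 0/1 spike
z = (y_ih[1] + y_hh[1] >= 0).float()       # update gate as a 0/1 spike
n = (y_ih[2] + r * y_hh[2] >= 0).float()   # candidate state
h = (1 - z) * n + z * h                    # same update as the plain branch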
Example #2
    def drop(self, batch_size: int):
        # sample an independent Bernoulli keep-mask for every example in the
        # batch; entries survive with probability 1 - self.p
        mask_w = (torch.rand_like(self.weight.unsqueeze(0).repeat([batch_size] + [1] * self.weight.dim())) > self.p)
        # self.dropped_w = mask_w.to(self.weight) * self.weight  # shape = [batch_size, out_features, in_features]
        self.dropped_w = accelerating.mul(self.weight.unsqueeze(0).repeat(batch_size, 1, 1), mask_w)

        if self.bias is not None:
            mask_b = (torch.rand_like(self.bias.unsqueeze(0).repeat([batch_size] + [1] * self.bias.dim())) > self.p)
            # self.dropped_b = mask_b.to(self.bias) * self.bias
            self.dropped_b = accelerating.mul(self.bias.unsqueeze(0).repeat(batch_size, 1), mask_b)
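
This is per-sample weight dropout: every example in the batch draws its own keep-mask over the whole weight matrix, giving ``dropped_w`` the shape ``[batch_size, out_features, in_features]`` noted in the comment. A self-contained sketch of the same masking, with plain multiplication in place of ``accelerating.mul`` and an assumed ``bmm``-based apply step that is not shown in the source:

import torch

batch_size, out_features, in_features, p = 4, 3, 5, 0.5
weight = torch.rand(out_features, in_features)

# one independent keep-mask per sample, as in drop() above
mask_w = torch.rand(batch_size, out_features, in_features) > p
dropped_w = weight.unsqueeze(0).repeat(batch_size, 1, 1) * mask_w.to(weight)

# hypothetical apply step: each sample uses its own masked weight matrix,
# [B, out, in] @ [B, in, 1] -> [B, out]
x = torch.rand(batch_size, in_features)
y = torch.bmm(dropped_w, x.unsqueeze(2)).squeeze(2)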
Example #3
    def forward(self, x: torch.Tensor):
        if self.training:
            if self.mask is None:
                # the mask is generated once and then reused at every
                # subsequent time step until it is reset elsewhere
                self.create_mask(x)
            if self.dropout_spikes:
                # x is spiking data, so the spike-optimized multiply applies
                return accelerating.mul(self.mask, x)
            else:
                return x * self.mask
        else:
            return x
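
Because ``self.mask`` is only created when it is ``None``, the same mask is reused at every subsequent forward call, i.e. at every time step of a simulation, until it is reset. A self-contained sketch of that behaviour; the scaling by 1 / (1 - p) is an assumed inverted-dropout convention, since ``create_mask`` is not shown in the source:

import torch

p, T, N, features = 0.2, 8, 4, 32
# one mask drawn up front, then applied unchanged at every step
mask = (torch.rand(N, features) > p).float() / (1 - p)
for t in range(T):
    spikes_t = (torch.rand(N, features) > 0.5).float()
    out_t = spikes_t * mask   # the same mask at every time step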
Example #4
def spike_mse_loss(x: torch.Tensor, spikes: torch.Tensor):
    '''
    .. _spike_mse_loss-en:

    :param x: an arbitrary tensor
    :param spikes: a spiking tensor. The elements of ``spikes`` must be ``0`` and ``1``, or ``False`` and ``True``, and ``spikes.shape``
        must be the same as ``x.shape``
    :return: the element-wise mean squared error (squared L2 norm) between ``x`` and ``spikes``

    Compared with ``torch.nn.functional.mse_loss()``, this function is accelerated for spiking data. Because every element
    :math:`s` of ``spikes`` is ``0`` or ``1``, we have :math:`s^{2} = s`, so the computation can be rewritten as

    .. math::
        (x - s)^{2} = x^{2} + s^{2} - 2xs = x^{2} + (1 - 2x)s

    .. admonition:: Note
        :class: note

        Since the computation graph is changed, the gradient :math:`\\frac{\\partial L}{\\partial s}` computed by this
        function differs from the one computed by ``torch.nn.functional.mse_loss()``.

    '''
    return (x.square() + accelerating.mul(1 - 2 * x, spikes)).mean()
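
A quick numerical check (not part of the source) that the rewritten forward value agrees with ``torch.nn.functional.mse_loss()`` for 0/1 spikes, using plain multiplication in place of ``accelerating.mul``:

import torch
import torch.nn.functional as F

x = torch.rand(4, 8)
s = (torch.rand(4, 8) > 0.5).float()   # elements are exactly 0 or 1

# (x - s)^2 == x^2 + (1 - 2x) * s elementwise, because s ** 2 == s
manual = (x.square() + (1 - 2 * x) * s).mean()
assert torch.allclose(manual, F.mse_loss(x, s))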
Example #5
    def forward(self, x: torch.Tensor, hc=None):
        '''
        .. _SpikingLSTMCell.forward-en:

        :param x: the input tensor with ``shape = [batch_size, input_size]``
        :type x: torch.Tensor

        :param hc: (h_0, c_0)
                h_0 : torch.Tensor
                    ``shape = [batch_size, hidden_size]``, tensor containing the initial hidden state for each element in the batch
                c_0 : torch.Tensor
                    ``shape = [batch_size, hidden_size]``, tensor containing the initial cell state for each element in the batch
                If (h_0, c_0) is not provided, both ``h_0`` and ``c_0`` default to zero
        :type hc: tuple or None
        :return: (h_1, c_1) :
                h_1 : torch.Tensor
                    ``shape = [batch_size, hidden_size]``, tensor containing the next hidden state for each element in the batch
                c_1 : torch.Tensor
                    ``shape = [batch_size, hidden_size]``, tensor containing the next cell state for each element in the batch
        :rtype: tuple
        '''
        if hc is None:
            h = torch.zeros(size=[x.shape[0], self.hidden_size],
                            dtype=torch.float,
                            device=x.device)
            c = torch.zeros_like(h)
        else:
            h = hc[0]
            c = hc[1]

        if self.surrogate_function2 is not None:
            # both surrogate functions must agree on whether they emit spikes
            assert self.surrogate_function1.spiking == self.surrogate_function2.spiking

        if self.surrogate_function2 is None:
            i, f, g, o = torch.split(self.surrogate_function1(
                self.linear_ih(x) + self.linear_hh(h)),
                                     self.hidden_size,
                                     dim=1)
        else:
            i, f, g, o = torch.split(self.linear_ih(x) + self.linear_hh(h),
                                     self.hidden_size,
                                     dim=1)
            i = self.surrogate_function1(i)
            f = self.surrogate_function1(f)
            g = self.surrogate_function2(g)
            o = self.surrogate_function1(o)

        if self.surrogate_function1.spiking:
            # spike-based acceleration can be used here
            # c = f * c + i * g
            c = accelerating.mul(c, f) + accelerating.mul(i, g, True)
            # h = o * c
            h = accelerating.mul(c, o)
        else:
            c = c * f + i * g
            h = c * o

        return h, c
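
A self-contained sketch of the same cell unrolled over several time steps, following the non-accelerated branch and using a hard threshold in place of the surrogate functions; all names and sizes are illustrative, not from the source:

import torch
import torch.nn as nn

input_size, hidden_size, batch_size, T = 16, 32, 4, 8
linear_ih = nn.Linear(input_size, 4 * hidden_size)
linear_hh = nn.Linear(hidden_size, 4 * hidden_size)

h = torch.zeros(batch_size, hidden_size)
c = torch.zeros_like(h)
for t in range(T):
    x = torch.rand(batch_size, input_size)
    # one split yields the pre-activations of the i, f, g and o gates
    i, f, g, o = torch.split(linear_ih(x) + linear_hh(h), hidden_size, dim=1)
    i, f, g, o = [(gate >= 0).float() for gate in (i, f, g, o)]
    c = c * f + i * g   # same as the non-accelerated branch
    h = c * o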