def test_izhikevich_feedforward_layer(spiking_method):
    """Run a random (10, 5, 4) tensor through one Izhikevich layer and check shapes.

    The layer output must have the same shape as the input, and every element
    of the returned state must have the input shape with its leading axis
    dropped, i.e. (5, 4). The leading axis is presumably time; this test does
    not confirm it.

    Args:
        spiking_method: Spiking behaviour handed to ``Izhikevich``; presumably
            supplied by a pytest fixture defined elsewhere.
    """
    input_shape = (10, 5, 4)
    neuron_layer = Izhikevich(spiking_method)
    inputs = torch.randn(*input_shape)
    output, state = neuron_layer(inputs)
    assert output.shape == input_shape
    # Each state component drops the leading (10-long) axis of the input.
    for component in state:
        assert component.shape == input_shape[1:]
def test_izhikevich_feedforward_layer_backward_iteration(spiking_method):
    """Feed the output and state of one call back into a second call, then backprop.

    The test checks that values produced by a first forward pass can serve as
    input and initial state to a second pass on the same model. It then calls
    ``backward()`` on the summed second output; the test fails if that raises.

    Args:
        spiking_method: Spiking behaviour handed to ``Izhikevich``; presumably
            supplied by a pytest fixture defined elsewhere.
    """
    # Tests that gradient variables can be used in subsequent applications
    neuron_model = Izhikevich(spiking_method)
    inputs = torch.ones(10, 6)
    first_output, first_state = neuron_model(inputs)
    second_output, _ = neuron_model(first_output, first_state)
    second_output.sum().backward()
def __init__(self, spiking_method: IzhikevichSpikingBehavior):
    """Build a two-layer network whose layers share one spiking behaviour.

    Args:
        spiking_method: Behaviour object stored on the network and passed
            unchanged to both ``Izhikevich`` layers.
    """
    # Zero-argument super() is the Python 3 idiom; it is equivalent to the
    # legacy super(SNNetwork, self) form and survives a class rename.
    super().__init__()
    self.spiking_method = spiking_method
    self.l0 = Izhikevich(spiking_method)
    self.l1 = Izhikevich(spiking_method)
    # Per-layer state slots, initialised to None. They are presumably filled
    # in by forward() (not visible here) — TODO confirm.
    self.s0 = self.s1 = None
def test_izhikevich_feedforward_layer_backward(spiking_method):
    """Run one forward pass on an all-ones (10, 12) input and backprop its sum.

    The returned state is ignored. The test fails if the forward pass or
    ``backward()`` raises.

    Args:
        spiking_method: Spiking behaviour handed to ``Izhikevich``; presumably
            supplied by a pytest fixture defined elsewhere.
    """
    neuron_model = Izhikevich(spiking_method)
    inputs = torch.ones(10, 12)
    output, _state = neuron_model(inputs)
    output.sum().backward()
def test_izhikevich_in_time(spiking_method):
    """Check that a random (10, 5, 2) input yields an output of identical shape.

    Args:
        spiking_method: Spiking behaviour handed to ``Izhikevich``; presumably
            supplied by a pytest fixture defined elsewhere.
    """
    expected_shape = (10, 5, 2)
    neuron_layer = Izhikevich(spiking_method)
    sequence = torch.randn(*expected_shape)
    result, _ = neuron_layer(sequence)
    assert result.shape == expected_shape