def test_spike_latency_encode_without_batch_3():
    """Latency encoding keeps only each neuron's first spike: timestep 1 is
    all-ones, so after encoding it must be the complement of timestep 0."""
    input_spikes = torch.as_tensor(
        [
            [
                [1.0, 1.0, 0.0, 1.0, 1.0],
                [0.0, 0.0, 0.0, 0.0, 1.0],
                [1.0, 0.0, 1.0, 0.0, 0.0],
                [1.0, 0.0, 1.0, 0.0, 1.0],
                [0.0, 1.0, 1.0, 0.0, 0.0],
                [1.0, 0.0, 0.0, 0.0, 1.0],
                [0.0, 0.0, 0.0, 1.0, 0.0],
            ],
            [
                [1.0, 1.0, 1.0, 1.0, 1.0],
                [1.0, 1.0, 1.0, 1.0, 1.0],
                [1.0, 1.0, 1.0, 1.0, 1.0],
                [1.0, 1.0, 1.0, 1.0, 1.0],
                [1.0, 1.0, 1.0, 1.0, 1.0],
                [1.0, 1.0, 1.0, 1.0, 1.0],
                [1.0, 1.0, 1.0, 1.0, 1.0],
            ],
        ]
    )
    encoded = spike_latency_encode(input_spikes)
    # Timestep 0 is passed through unchanged; timestep 1 keeps only neurons
    # that had not yet fired at timestep 0.
    want = input_spikes.clone()
    want[1] = torch.as_tensor(
        [
            [0.0, 0.0, 1.0, 0.0, 0.0],
            [1.0, 1.0, 1.0, 1.0, 0.0],
            [0.0, 1.0, 0.0, 1.0, 1.0],
            [0.0, 1.0, 0.0, 1.0, 0.0],
            [1.0, 0.0, 0.0, 1.0, 1.0],
            [0.0, 1.0, 1.0, 1.0, 0.0],
            [1.0, 1.0, 1.0, 0.0, 1.0],
        ]
    )
    assert torch.equal(encoded, want)
def test_spike_latency_encode_with_batch():
    """A large constant current makes every neuron fire at the first step;
    latency encoding must then suppress all later spikes."""
    currents = torch.as_tensor([[100, 100], [100, 100]])
    input_spikes = constant_current_lif_encode(currents, 5)
    encoded = spike_latency_encode(input_spikes)
    # Only the very first timestep may contain spikes.
    want = torch.zeros((5, 2, 2))
    want[0] = torch.ones((2, 2))
    assert torch.equal(encoded, want)
def forward(self, input_spikes):
    """Apply spike-latency encoding to a spike train.

    Delegates to :func:`encode.spike_latency_encode`, keeping only the
    first spike emitted by each neuron across the time dimension.
    """
    encoded = encode.spike_latency_encode(input_spikes)
    return encoded
def test_spike_latency_encode_without_batch_2():
    """Two timesteps, 2x3 neurons: spikes already seen at t=0 must be
    removed from t=1."""
    input_spikes = torch.as_tensor(
        [
            [[0, 1, 1], [1, 1, 1]],
            [[1, 1, 1], [1, 1, 1]],
        ]
    )
    encoded = spike_latency_encode(input_spikes)
    want = torch.as_tensor(
        [
            [[0, 1, 1], [1, 1, 1]],
            [[1, 0, 0], [0, 0, 0]],
        ]
    )
    assert torch.equal(encoded, want)
def test_spike_latency_encode_without_batch():
    """Minimal 2-timestep case: neurons firing at t=0 are silenced at t=1."""
    input_spikes = torch.as_tensor([[0, 1, 1, 0], [1, 1, 1, 0]])
    encoded = spike_latency_encode(input_spikes)
    want = torch.as_tensor([[0, 1, 1, 0], [1, 0, 0, 0]])
    assert torch.equal(encoded, want)