def test_spike_latency_encode():
    """Check that latency encoding keeps only the earliest spike of each neuron.

    The zero-valued input stays silent, every input of 100 spikes in the
    first of the two time steps, and the second time step is left empty.
    """
    inputs = torch.as_tensor([[0, 100, 100], [100, 100, 100]])
    pipeline = torch.nn.Sequential(
        ConstantCurrentLIFEncoder(2),
        SpikeLatencyEncoder(),
    )
    encoded = pipeline(inputs)

    # Leading axis has length 2 (the encoder's time steps); only step 0 spikes.
    expected = torch.zeros((2, 2, 3))
    expected[0] = torch.as_tensor([[0, 1, 1], [1, 1, 1]])
    assert torch.equal(encoded, expected)
def __init__(
    self,
    in_channels,
    out_channels,
    kernel_size,
    stride=1,
    padding=1,
    dilation=1,
    groups=1,
    bias=True,
    seq_length=100,
):
    """Allocate the convolution weights, encoder and LIF parameters of the cell.

    Args:
        in_channels: Channels of the input; must be divisible by ``groups``.
        out_channels: Channels of the hidden state; must be divisible by
            ``groups``.
        kernel_size: Convolution kernel size (int or pair).
        stride: Convolution stride (int or pair), stored as ``self.stride``.
        padding: Convolution padding (int or pair), stored as ``self.padding``.
        dilation: Convolution dilation (int or pair).
        groups: Number of convolution groups.
        bias: If true, allocate ``bias_ih``/``bias_hh``/``bias_ch``; otherwise
            register them as ``None``.
        seq_length: Number of time steps of the constant-current LIF encoder.

    Raises:
        ValueError: If ``in_channels`` or ``out_channels`` is not divisible
            by ``groups``.
    """
    super(ConvLSTMCellSpike, self).__init__()
    if in_channels % groups != 0:
        raise ValueError('in_channels must be divisible by groups')
    if out_channels % groups != 0:
        raise ValueError('out_channels must be divisible by groups')
    kernel_size = _pair(kernel_size)
    stride = _pair(stride)
    padding = _pair(padding)
    dilation = _pair(dilation)
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.kernel_size = kernel_size
    self.stride = stride
    self.padding = padding
    # Hidden-state padding depends on the kernel size only; stride, padding
    # and dilation do not enter it. Presumably chosen so the hidden-state
    # convolution keeps spatial size for odd kernels — confirm in forward.
    self.padding_h = tuple(k // 2 for k in kernel_size)
    self.dilation = dilation
    self.groups = groups
    # Stacked gate weights: 4 * out_channels rows for input->hidden and
    # hidden->hidden. weight_ch has only 3 * out_channels rows (presumably
    # peephole terms for three gates — TODO confirm against forward).
    # torch.Tensor(...) allocates uninitialised memory; reset_parameters()
    # (defined elsewhere) is expected to fill it.
    self.weight_ih = Parameter(
        torch.Tensor(4 * out_channels, in_channels // groups, *kernel_size))
    self.weight_hh = Parameter(
        torch.Tensor(4 * out_channels, out_channels // groups, *kernel_size))
    self.weight_ch = Parameter(
        torch.Tensor(3 * out_channels, out_channels // groups, *kernel_size))
    if bias:
        self.bias_ih = Parameter(torch.Tensor(4 * out_channels))
        self.bias_hh = Parameter(torch.Tensor(4 * out_channels))
        self.bias_ch = Parameter(torch.Tensor(3 * out_channels))
    else:
        # Register as None so the attributes exist even without bias.
        self.register_parameter('bias_ih', None)
        self.register_parameter('bias_hh', None)
        self.register_parameter('bias_ch', None)
    # NOTE(review): purpose of this all-zero buffer is not visible here.
    self.register_buffer('wc_blank', torch.zeros(1, 1, 1, 1))
    self.constant_current_encoder = ConstantCurrentLIFEncoder(
        seq_length=seq_length)
    self.seq_length = seq_length
    self.lif_parameters = LIFParameters(method="super", alpha=torch.tensor(100))
    self.lif_t_parameters = LIFParameters(method="tanh", alpha=torch.tensor(100))
    self.reset_parameters()
def test_spike_latency_encode_backward():
    """Check that backward() through latency-encoded input and a LIF layer succeeds.

    The encoded input is fed to a ``LIFRecurrent(5, 2)`` layer, the outputs
    are summed, and ``backward()`` is called with autograd anomaly detection
    enabled so that problems in the backward pass raise.
    """
    timesteps = 20
    data = torch.randn(7, 5) + 10
    encoder = torch.nn.Sequential(
        ConstantCurrentLIFEncoder(seq_length=timesteps),
        SpikeLatencyEncoder(),
    )
    layer = LIFRecurrent(5, 2)
    # Scope anomaly detection to this test. Leaving it switched on globally
    # would slow down (and change the behaviour of) every test that runs
    # later in the same process.
    with torch.autograd.set_detect_anomaly(True):
        encoded = encoder(data)
        out, _ = layer(encoded)
        out.sum().backward()
def __init__(
    self,
    input_features,
    seq_length,
    model="super",
    only_first_spike=False,
):
    """Build ``LIFConvNet``: a constant-current encoder feeding ``ConvNet4``.

    Args:
        input_features: Stored as ``self.input_features``.
        seq_length: Time steps of the constant-current LIF encoder; also
            stored as ``self.seq_length``.
        model: Method name forwarded to ``ConvNet4(method=...)``.
        only_first_spike: Flag stored as ``self.only_first_spike``.
    """
    super(LIFConvNet, self).__init__()
    # Plain configuration values.
    self.input_features = input_features
    self.seq_length = seq_length
    self.only_first_spike = only_first_spike
    # Submodules, registered in the same order as before (encoder first).
    self.constant_current_encoder = ConstantCurrentLIFEncoder(seq_length=seq_length)
    self.rsnn = ConvNet4(method=model)
def __init__(self, model="super"):
    """Build the spiking ``Policy``: encoder, LIF hidden cell and ``LICell`` readout.

    Args:
        model: Method name passed to ``LIFParameters(method=...)`` for the
            hidden LIF cell. Defaults to ``"super"``, the previously
            hard-coded value, so ``Policy()`` behaves exactly as before.
            This mirrors the ``model`` argument of ``LSNNPolicy``.
    """
    super(Policy, self).__init__()
    self.state_dim = 4
    self.input_features = 16
    self.hidden_features = 128
    self.output_features = 2
    # 40 encoder time steps.
    self.constant_current_encoder = ConstantCurrentLIFEncoder(40)
    # The cell takes 2 * state_dim inputs. Presumably the forward pass
    # expands each state dimension into two channels — confirm in forward.
    self.lif = LIFCell(
        2 * self.state_dim,
        self.hidden_features,
        p=LIFParameters(method=model, alpha=100.0),
    )
    self.dropout = torch.nn.Dropout(p=0.5)
    self.readout = LICell(self.hidden_features, self.output_features)
    # Per-episode buffers; presumably consumed by a policy-gradient update
    # elsewhere — verify against the training loop.
    self.saved_log_probs = []
    self.rewards = []
def __init__(self, model="super"):
    """Build ``LSNNPolicy``: encoder, ``LSNNCell`` hidden layer and ``LICell`` readout.

    Args:
        model: Method name forwarded to ``LSNNParameters(method=...)``.
    """
    super(LSNNPolicy, self).__init__()
    # Layer sizes. input_features is not consumed by any layer built here.
    self.state_dim, self.input_features = 4, 16
    self.hidden_features, self.output_features = 128, 2
    # 40 encoder time steps.
    self.constant_current_encoder = ConstantCurrentLIFEncoder(40)
    lsnn_params = LSNNParameters(method=model, alpha=100.0)
    self.lif_layer = LSNNCell(
        2 * self.state_dim, self.hidden_features, p=lsnn_params
    )
    self.dropout = torch.nn.Dropout(p=0.5)
    self.readout = LICell(self.hidden_features, self.output_features)
    # Per-episode buffers filled elsewhere.
    self.saved_log_probs, self.rewards = [], []
def test_spike_latency_encode_max_spikes():
    """Check that latency encoding emits one spike per neuron over a long window.

    Ten neurons are driven by a constant input of 1.1 for 128 steps. After
    latency encoding, exactly ten spikes remain in total, one per neuron.
    """
    encoder = torch.nn.Sequential(
        ConstantCurrentLIFEncoder(seq_length=128),
        SpikeLatencyEncoder(),
    )
    spikes = encoder(1.1 * torch.ones(10))
    # Compare a plain Python number (.item()) instead of the legacy `.data`
    # attribute, which bypasses autograd and is discouraged.
    assert torch.sum(spikes).item() == 10
    # Summing over the leading (time) axis gives one spike per neuron.
    assert bool((spikes.sum(0) == 1).all())
def test_constant_current_lif_encode():
    """Check that a zero input current produces no spikes in any of the 2 steps."""
    silent = torch.as_tensor([0, 0, 0, 0])
    encoder = ConstantCurrentLIFEncoder(2)
    spikes = encoder.forward(silent)
    no_spikes = torch.zeros((2, 4))
    assert torch.equal(spikes, no_spikes)
def test_spike_latency_encode_chain():
    """Check that the constant-current and latency encoders chain on a 2-D batch.

    The output gains a leading time axis of length ``seq_length`` (2), as in
    ``test_spike_latency_encode``. After latency encoding, each neuron
    carries at most one spike.
    """
    data = torch.randn(7, 5) + 10
    encoder = torch.nn.Sequential(ConstantCurrentLIFEncoder(2), SpikeLatencyEncoder())
    spikes = encoder(data)
    # Previously the test only checked that the call did not raise.
    assert spikes.shape == (2, 7, 5)
    assert bool((spikes.sum(0) <= 1).all())