示例#1
0
def test_spike_latency_encode():
    """A 2-step LIF + latency encoding keeps only the first-step spikes.

    Inputs of 100 spike at time step 0, the zero input never spikes, and the
    second time step is expected to be completely silent.
    """
    inputs = torch.as_tensor([[0, 100, 100], [100, 100, 100]])
    pipeline = torch.nn.Sequential(
        ConstantCurrentLIFEncoder(2),
        SpikeLatencyEncoder(),
    )
    spikes = pipeline(inputs)

    # Expected layout is (time steps, *inputs.shape) = (2, 2, 3).
    first_step = torch.tensor([[0.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
    second_step = torch.zeros(2, 3)
    expected = torch.stack((first_step, second_step))
    assert torch.equal(spikes, expected)
示例#2
0
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=1,
                 dilation=1,
                 groups=1,
                 bias=True,
                 seq_length=100):
        """Set up a convolutional LSTM cell with spiking (LIF) components.

        Args:
            in_channels: channel count of the input; sets the second
                dimension of ``weight_ih`` (divided by ``groups``).
            out_channels: channel count of the hidden state; sets the gate
                dimension of every weight and bias.
            kernel_size: convolution kernel size, passed through ``_pair``.
            stride: passed through ``_pair`` and stored.
            padding: passed through ``_pair`` and stored.
            dilation: passed through ``_pair`` and stored.
            groups: group count; both channel counts must be divisible by it.
            bias: when false, ``bias_ih``/``bias_hh``/``bias_ch`` are
                registered as ``None``.
            seq_length: forwarded to ``ConstantCurrentLIFEncoder`` and stored.

        Raises:
            ValueError: if ``in_channels`` or ``out_channels`` is not
                divisible by ``groups``.
        """
        super(ConvLSTMCellSpike, self).__init__()
        # Validate channel/group compatibility before allocating anything.
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')

        kernel_size, stride, padding, dilation = (
            _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation))
        self.in_channels, self.out_channels = in_channels, out_channels
        self.kernel_size, self.stride, self.padding = (
            kernel_size, stride, padding)
        # Padding for the hidden-state convolution: k // 2 along each axis.
        # zip() truncates to the shortest geometry tuple; only k is used.
        # NOTE(review): dilation is not taken into account here; confirm
        # forward() does not dilate the hidden-state convolution.
        self.padding_h = tuple(
            k // 2 for k, *_ in zip(kernel_size, stride, padding, dilation))
        self.dilation, self.groups = dilation, groups

        # Gate-stacked weights: 4 * out_channels rows for the input and
        # hidden convolutions, 3 * out_channels for the ``ch`` weights
        # (presumably peephole connections -- TODO confirm). Registration
        # order fixes parameters()/state_dict() ordering.
        in_per_group = in_channels // groups
        hidden_per_group = out_channels // groups
        self.weight_ih = Parameter(
            torch.Tensor(4 * out_channels, in_per_group, *kernel_size))
        self.weight_hh = Parameter(
            torch.Tensor(4 * out_channels, hidden_per_group, *kernel_size))
        self.weight_ch = Parameter(
            torch.Tensor(3 * out_channels, hidden_per_group, *kernel_size))

        # Biases follow the same gate layout (4, 4, 3); when disabled they are
        # registered as None so the attribute names still resolve.
        bias_layout = (('bias_ih', 4), ('bias_hh', 4), ('bias_ch', 3))
        for bias_name, gate_count in bias_layout:
            self.register_parameter(
                bias_name,
                Parameter(torch.Tensor(gate_count * out_channels))
                if bias else None)
        self.register_buffer('wc_blank', torch.zeros(1, 1, 1, 1))

        self.constant_current_encoder = ConstantCurrentLIFEncoder(
            seq_length=seq_length)
        self.seq_length = seq_length
        # Two LIF parameter sets, "super" and "tanh" methods, each with an
        # integer alpha tensor of 100.
        self.lif_parameters = LIFParameters(method="super",
                                            alpha=torch.tensor(100))
        self.lif_t_parameters = LIFParameters(method="tanh",
                                              alpha=torch.tensor(100))

        # torch.Tensor(shape) above allocates uninitialised storage;
        # reset_parameters (defined elsewhere) presumably fills it.
        self.reset_parameters()
示例#3
0
def test_spike_latency_encode_backward():
    """Gradients flow back through encoder -> latency encoder -> LIFRecurrent.

    Anomaly detection is enabled only for the duration of this test. Used as
    a context manager, ``set_detect_anomaly`` restores the previous global
    setting on exit, so later tests neither pay its runtime cost nor inherit
    its behaviour.
    """
    timesteps = 20
    # Offset by 10 -- presumably to give a strong input current so that
    # spikes actually occur; TODO confirm.
    data = torch.randn(7, 5) + 10
    with torch.autograd.set_detect_anomaly(True):
        encoder = torch.nn.Sequential(
            ConstantCurrentLIFEncoder(seq_length=timesteps),
            SpikeLatencyEncoder())
        layer = LIFRecurrent(5, 2)
        encoded = encoder(data)
        out, _ = layer(encoded)
        out = out.sum()
        out.backward()
示例#4
0
 def __init__(
     self,
     input_features,
     seq_length,
     model="super",
     only_first_spike=False,
 ):
     """Assemble the LIF convolutional network.

     Args:
         input_features: stored as ``self.input_features``.
         seq_length: number of time steps; forwarded to the
             ``ConstantCurrentLIFEncoder`` and stored as ``self.seq_length``.
         model: forwarded as ``method`` to ``ConvNet4``.
         only_first_spike: flag stored as ``self.only_first_spike``;
             presumably consumed by ``forward`` (not visible here).
     """
     super(LIFConvNet, self).__init__()
     # Plain configuration attributes.
     self.input_features = input_features
     self.seq_length = seq_length
     self.only_first_spike = only_first_spike

     # Submodules: the encoder is registered before the network, which fixes
     # the parameters()/state_dict() ordering.
     self.constant_current_encoder = ConstantCurrentLIFEncoder(
         seq_length=seq_length)
     self.rsnn = ConvNet4(method=model)
示例#5
0
    def __init__(self):
        """Create the LIF-based policy network and its bookkeeping lists.

        Layers: a ``ConstantCurrentLIFEncoder`` built with 40 (presumably the
        sequence length), a ``LIFCell`` mapping ``2 * state_dim`` inputs to
        ``hidden_features`` units, dropout with p=0.5, and an ``LICell``
        readout producing ``output_features`` values.
        """
        super(Policy, self).__init__()
        self.state_dim, self.input_features = 4, 16
        self.hidden_features, self.output_features = 128, 2

        # Construction order matters: layer constructors may draw from the
        # RNG, and registration order fixes parameters()/state_dict() order.
        self.constant_current_encoder = ConstantCurrentLIFEncoder(40)
        lif_params = LIFParameters(method="super", alpha=100.0)
        # NOTE(review): the cell input width is 2 * state_dim, presumably
        # because forward() expands each state component into two channels.
        self.lif = LIFCell(2 * self.state_dim, self.hidden_features,
                           p=lif_params)
        self.dropout = torch.nn.Dropout(p=0.5)
        self.readout = LICell(self.hidden_features, self.output_features)

        # NOTE(review): input_features (16) is not used by any layer built
        # here; confirm whether other code relies on it.
        # Lists presumably appended to by the training loop.
        self.saved_log_probs = []
        self.rewards = []
示例#6
0
    def __init__(self, model="super"):
        """Create the LSNN-based policy network and its bookkeeping lists.

        Args:
            model: forwarded as ``method`` to ``LSNNParameters`` (with
                alpha=100.0) for the recurrent ``LSNNCell``.
        """
        super(LSNNPolicy, self).__init__()
        self.state_dim, self.input_features = 4, 16
        self.hidden_features, self.output_features = 128, 2

        # Construction order matters: layer constructors may draw from the
        # RNG, and registration order fixes parameters()/state_dict() order.
        self.constant_current_encoder = ConstantCurrentLIFEncoder(40)
        lsnn_params = LSNNParameters(method=model, alpha=100.0)
        # NOTE(review): the cell input width is 2 * state_dim, presumably
        # because forward() expands each state component into two channels.
        self.lif_layer = LSNNCell(2 * self.state_dim, self.hidden_features,
                                  p=lsnn_params)
        self.dropout = torch.nn.Dropout(p=0.5)
        self.readout = LICell(self.hidden_features, self.output_features)

        # Lists presumably appended to by the training loop.
        self.saved_log_probs = []
        self.rewards = []
示例#7
0
def test_spike_latency_encode_max_spikes():
    """Each input element emits exactly one spike under latency encoding.

    Ten identical inputs of 1.1 are encoded over 128 steps; the latency
    encoder is expected to keep a single spike per element, so the total
    spike count is 10.
    """
    encoder = torch.nn.Sequential(ConstantCurrentLIFEncoder(seq_length=128),
                                  SpikeLatencyEncoder())
    spikes = encoder(1.1 * torch.ones(10))
    # .item() yields a plain Python number; the legacy .data escape hatch
    # from autograd is not needed for a value check.
    assert spikes.sum().item() == 10
示例#8
0
def test_constant_current_lif_encode():
    """A zero input current never makes the LIF encoder spike."""
    data = torch.as_tensor([0, 0, 0, 0])
    encoder = ConstantCurrentLIFEncoder(2)
    # Invoke the module itself rather than .forward() so that any registered
    # hooks run, exactly as in normal use.
    z = encoder(data)
    # Output gains a leading time dimension whose length is the first
    # constructor argument (2).
    assert torch.equal(z, torch.zeros((2, 4)))
示例#9
0
def test_spike_latency_encode_chain():
    """The LIF + latency encoder pipeline handles a 2-D batch.

    Checks the output layout (time steps, *input shape) and that latency
    encoding leaves at most one spike per input element.
    """
    data = torch.randn(7, 5) + 10
    encoder = torch.nn.Sequential(ConstantCurrentLIFEncoder(2),
                                  SpikeLatencyEncoder())
    spikes = encoder(data)
    # Leading time dimension of length 2, followed by the input shape.
    assert spikes.shape == (2, 7, 5)
    # Summing over time counts spikes per element; latency coding allows <= 1.
    assert torch.all(spikes.sum(dim=0) <= 1)