Example #1
0
    def __init__(self,
                 device,
                 num_channels=1,
                 feature_size=28,
                 method="super",
                 dtype=torch.float):
        """Four-layer convolutional spiking network (two convs, one FC head).

        Args:
            device: device the caller intends to run the network on.
            num_channels: channels of the input image.
            feature_size: side length of the square input image.
            method: surrogate-gradient method forwarded to LIFParameters.
            dtype: tensor dtype stored for later use.
        """
        super(ConvNet4, self).__init__()
        # Spatial side length after each 5x5 conv (no padding) and /2 pooling.
        side_after_conv1 = feature_size - 4
        side_after_conv2 = int(side_after_conv1 / 2) - 4
        self.features = int(((feature_size - 4) / 2 - 4) / 2)

        self.conv1 = torch.nn.Conv2d(num_channels, 32, 5, 1)
        self.conv2 = torch.nn.Conv2d(32, 64, 5, 1)
        self.fc1 = torch.nn.Linear(64 * self.features * self.features, 1024)
        # LIFParameters is an immutable parameter record, so one instance can
        # safely back all three spiking layers.
        lif_p = LIFParameters(method=method, alpha=100.0)
        self.lif0 = LIFFeedForwardCell(
            (32, side_after_conv1, side_after_conv1), p=lif_p)
        self.lif1 = LIFFeedForwardCell(
            (64, side_after_conv2, side_after_conv2), p=lif_p)
        self.lif2 = LIFFeedForwardCell((1024, ), p=lif_p)
        self.out = LICell(1024, 10)
        self.device = device
        self.dtype = dtype
Example #2
0
File: conv.py — Project: stjordanis/norse
 def __init__(self,
              device,
              num_channels=1,
              feature_size=28,
              model="super",
              dtype=torch.float):
     """Two-conv spiking network with LIF feed-forward cells.

     Args:
         device: device the caller intends to run the network on.
         num_channels: channels of the input image.
         feature_size: side length of the square input image.
         model: surrogate-gradient method name.
         dtype: tensor dtype stored for later use.
     """
     super(ConvNet, self).__init__()
     self.features = int(((feature_size - 4) / 2 - 4) / 2)
     self.conv1 = torch.nn.Conv2d(num_channels, 20, 5, 1)
     self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
     self.fc1 = torch.nn.Linear(self.features * self.features * 50, 500)
     self.out = LICell(500, 10)
     self.device = device
     # BUG FIX: LIFParameters has no ``model`` field; its surrogate-gradient
     # selector is named ``method`` (see every other snippet in this file),
     # so ``model=model`` would raise a TypeError at construction time.
     self.lif0 = LIFFeedForwardCell(
         (20, feature_size - 4, feature_size - 4),
         p=LIFParameters(method=model, alpha=100.0),
     )
     self.lif1 = LIFFeedForwardCell(
         (50, int((feature_size - 4) / 2) - 4, int(
             (feature_size - 4) / 2) - 4),
         p=LIFParameters(method=model, alpha=100.0),
     )
     self.lif2 = LIFFeedForwardCell((500, ),
                                    p=LIFParameters(method=model,
                                                    alpha=100.0))
     self.dtype = dtype
Example #3
0
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=1,
                 dilation=1,
                 groups=1,
                 bias=True,
                 seq_length=100):
        """Convolutional LSTM cell with spiking (LIF) nonlinearities.

        Args:
            in_channels: channels of the input feature map.
            out_channels: channels of the hidden/cell state.
            kernel_size: convolution kernel size (int or pair).
            stride: convolution stride (int or pair).
            padding: convolution padding for the input path (int or pair).
            dilation: convolution dilation (int or pair).
            groups: number of blocked connections; must divide both
                in_channels and out_channels.
            bias: if True, register bias parameters for each weight path.
            seq_length: number of time steps for the constant-current LIF
                encoder.

        Raises:
            ValueError: if in_channels or out_channels is not divisible by
                groups.
        """
        super(ConvLSTMCellSpike, self).__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        # Normalize all scalar geometry arguments to 2-tuples.
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        # Half-kernel padding for the hidden-state path.
        # NOTE(review): only k is used here — s, p and d are unpacked but
        # ignored, so dilation is not factored into padding_h; confirm this
        # is intended.
        self.padding_h = tuple(
            k // 2
            for k, s, p, d in zip(kernel_size, stride, padding, dilation))
        self.dilation = dilation
        self.groups = groups
        # Input-to-hidden and hidden-to-hidden weights hold four gate blocks
        # each; the cell-to-hidden path holds three (presumably peephole
        # connections — the forward pass is not visible here).
        self.weight_ih = Parameter(
            torch.Tensor(4 * out_channels, in_channels // groups,
                         *kernel_size))
        self.weight_hh = Parameter(
            torch.Tensor(4 * out_channels, out_channels // groups,
                         *kernel_size))
        self.weight_ch = Parameter(
            torch.Tensor(3 * out_channels, out_channels // groups,
                         *kernel_size))
        if bias:
            self.bias_ih = Parameter(torch.Tensor(4 * out_channels))
            self.bias_hh = Parameter(torch.Tensor(4 * out_channels))
            self.bias_ch = Parameter(torch.Tensor(3 * out_channels))
        else:
            # Register the names so attribute access stays valid without bias.
            self.register_parameter('bias_ih', None)
            self.register_parameter('bias_hh', None)
            self.register_parameter('bias_ch', None)
        self.register_buffer('wc_blank', torch.zeros(1, 1, 1, 1))

        self.constant_current_encoder = ConstantCurrentLIFEncoder(
            seq_length=seq_length)
        self.seq_length = seq_length
        # Two surrogate-gradient parameter sets: one "super", one "tanh".
        self.lif_parameters = LIFParameters(method="super",
                                            alpha=torch.tensor(100))
        self.lif_t_parameters = LIFParameters(method="tanh",
                                              alpha=torch.tensor(100))

        self.reset_parameters()
Example #4
0
File: conv.py — Project: weilongzheng/norse
 def __init__(self,
              num_channels=1,
              feature_size=28,
              method="super",
              dtype=torch.float):
     """Two-conv spiking network with a linear leaky-integrator readout.

     Args:
         num_channels: channels of the input image.
         feature_size: side length of the square input image.
         method: surrogate-gradient method forwarded to LIFParameters.
         dtype: tensor dtype stored for later use.
     """
     super(ConvNet, self).__init__()
     self.features = int(((feature_size - 4) / 2 - 4) / 2)
     self.conv1 = torch.nn.Conv2d(num_channels, 20, 5, 1)
     self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
     self.fc1 = torch.nn.Linear(50 * self.features * self.features, 500)
     self.out = LILinearCell(500, 10)
     # One spiking cell per trainable layer, all sharing the same settings.
     self.lif0 = LIFCell(p=LIFParameters(method=method, alpha=100.0))
     self.lif1 = LIFCell(p=LIFParameters(method=method, alpha=100.0))
     self.lif2 = LIFCell(p=LIFParameters(method=method, alpha=100.0))
     self.dtype = dtype
Example #5
0
def lif_feed_forward_benchmark(parameters: BenchmarkParameters):
    """Time a dense feed-forward LIF layer over one Poisson-encoded sequence.

    Args:
        parameters (BenchmarkParameters): benchmark configuration providing
            features, batch_size, sequence_length, dt and device.

    Returns:
        float: wall-clock seconds spent in the simulation loop.
    """
    fc = torch.nn.Linear(parameters.features, parameters.features,
                         bias=False).to(parameters.device)
    T = parameters.sequence_length
    s = LIFFeedForwardState(
        v=torch.zeros(parameters.batch_size,
                      parameters.features).to(parameters.device),
        i=torch.zeros(parameters.batch_size,
                      parameters.features).to(parameters.device),
    )
    p = LIFParameters(alpha=100.0, method="heaviside")
    input_spikes = PoissonEncoder(T, dt=parameters.dt)(0.3 * torch.ones(
        parameters.batch_size, parameters.features, device=parameters.device))

    # BUG FIX: use a monotonic, high-resolution clock for the measurement;
    # time.time() has coarse granularity and can jump under clock changes.
    start = time.perf_counter()

    spikes = []
    for ts in range(T):
        x = fc(input_spikes[ts, :])
        z, s = lif_feed_forward_step(input_tensor=x,
                                     state=s,
                                     p=p,
                                     dt=parameters.dt)
        spikes.append(z)

    # Stack to force materialization of all outputs before stopping the clock.
    spikes = torch.stack(spikes)
    duration = time.perf_counter() - start
    return duration
Example #6
0
File: lif.py — Project: norse/norse
 def __init__(self,
              input_size: int,
              hidden_size: int,
              sparse: bool = False,
              p: LIFParameters = LIFParameters(),
              *args,
              **kwargs):
     """Recurrent LIF cell, optionally backed by the sparse step kernel.

     Args:
         input_size (int): number of input features.
         hidden_size (int): number of hidden neurons.
         sparse (bool): if True, use the sparse LIF step function.
         p (LIFParameters): neuron parameters.
         **kwargs: forwarded to the parent constructor.
     """
     self.sparse = sparse
     # The two branches duplicated the whole super().__init__ call and
     # differed only in the activation function — select it once instead.
     activation = lif_step_sparse if sparse else lif_step
     # NOTE(review): *args is accepted but never forwarded; preserved as-is,
     # but confirm whether it should be passed to the parent.
     super().__init__(
         activation=activation,
         state_fallback=self.initial_state,
         input_size=input_size,
         hidden_size=hidden_size,
         p=p,
         **kwargs,
     )
Example #7
0
 def __init__(self, p: LIFParameters = LIFParameters(), **kwargs):
     """Feed-forward LIF module.

     Args:
         p (LIFParameters): neuron parameters.
         **kwargs: forwarded to the parent constructor.
     """
     super().__init__(
         activation=lif_feed_forward_step,
         state_fallback=self.initial_state,
         p=p,
         **kwargs,
     )
Example #8
0
 def __init__(self, p: LIFParameters = LIFParameters(), **kwargs):
     """Feed-forward LIF module.

     Args:
         p (LIFParameters): neuron parameters.
         **kwargs: forwarded to the parent constructor; the step function
             and state fallback are passed positionally.
     """
     super().__init__(
         lif_feed_forward_step,
         self.initial_state,
         p=p,
         **kwargs,
     )
Example #9
0
def lif_mc_step(
    input_tensor: torch.Tensor,
    state: LIFState,
    input_weights: torch.Tensor,
    recurrent_weights: torch.Tensor,
    g_coupling: torch.Tensor,
    p: LIFParameters = LIFParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFState]:
    """Computes a single euler-integration step of a LIF multi-compartment
    neuron-model.

    The compartment voltages are first coupled through ``g_coupling`` with
    one explicit Euler step, then a regular recurrent LIF step is taken.

    Parameters:
        input_tensor (torch.Tensor): the input spikes at the current time step
        state (LIFState): current state of the neuron
        input_weights (torch.Tensor): synaptic weights for incoming spikes
        recurrent_weights (torch.Tensor): synaptic weights for recurrent spikes
        g_coupling (torch.Tensor): conductances between the neuron compartments
        p (LIFParameters): neuron parameters
        dt (float): Integration timestep to use

    Returns:
        Tuple[torch.Tensor, LIFState]: output spikes and the updated state.
    """
    # Couple the compartment voltages before the standard LIF integration.
    v_new = state.v + dt * torch.nn.functional.linear(state.v, g_coupling)
    return lif_step(
        input_tensor,
        LIFState(state.z, v_new, state.i),
        input_weights,
        recurrent_weights,
        p,
        dt,
    )
Example #10
0
File: test_snn.py — Project: norse/norse
def test_snn_recurrent_cell_weights_autapse_update():
    """Autapses (diagonal recurrent weights) stay zero through an update."""
    input_weights = torch.ones(3, 2)
    recurrent_weights = torch.nn.Parameter(torch.ones(3, 3))
    cell = snn.SNNRecurrentCell(
        lif_step,
        lambda x: LIFState(v=torch.zeros(3), i=torch.zeros(3),
                           z=torch.ones(3)),
        2,
        3,
        p=LIFParameters(v_th=torch.as_tensor(0.1)),
        input_weights=input_weights,
        recurrent_weights=recurrent_weights,
    )
    # The diagonal is zeroed at construction time.
    assert torch.all(torch.eq(cell.recurrent_weights.diag(), torch.zeros(3)))
    optimizer = torch.optim.Adam(cell.parameters())
    optimizer.zero_grad()
    outputs = []
    state = None
    for _ in range(10):
        z, state = cell(torch.ones(2), state)
        outputs.append(z)
    loss = torch.stack(outputs).sum()
    loss.backward()
    optimizer.step()
    weights = cell.recurrent_weights.clone().detach()
    # The cell actually spiked, and the update left the diagonal at zero.
    assert not z.sum() == 0.0
    assert torch.all(torch.eq(weights.diag(), torch.zeros(3)))
    weights.fill_diagonal_(1.0)
    assert not torch.all(torch.eq(weights, torch.ones(3, 3)))
Example #11
0
def test_lift_without_state_with_parameters():
    """lift() without an explicit state still applies custom parameters."""
    inputs = torch.ones(3, 2, 1)
    params = LIFParameters(v_th=torch.as_tensor(0.3), method="tanh")
    stepper = lift(lif_feed_forward_step, p=params)
    z, state = stepper(inputs)
    assert z.shape == (3, 2, 1)
    assert state.v.shape == (2, 1)
    assert state.i.shape == (2, 1)
Example #12
0
class LIFCorrelationParameters(NamedTuple):
    """Parameters of a LIF neuron with correlation sensors on its input
    and recurrent synapses.

    Parameters:
        lif_parameters (LIFParameters): parameters of the LIF neuron
            integration
        input_correlation_parameters (CorrelationSensorParameters):
            correlation sensor parameters for the input synapses
        recurrent_correlation_parameters (CorrelationSensorParameters):
            correlation sensor parameters for the recurrent synapses
    """

    lif_parameters: LIFParameters = LIFParameters()
    input_correlation_parameters: CorrelationSensorParameters = (
        CorrelationSensorParameters()
    )
    recurrent_correlation_parameters: CorrelationSensorParameters = (
        CorrelationSensorParameters()
    )
Example #13
0
class LIFRefracParameters(NamedTuple):
    """Parameters of a LIF neuron with absolute refractory period.

    Parameters:
        lif (LIFParameters): parameters of the LIF neuron integration
        rho_reset (torch.Tensor): value the refractory counter is reset to
            (it then counts down towards zero)
    """

    lif: LIFParameters = LIFParameters()
    rho_reset: torch.Tensor = torch.as_tensor(5.0)
Example #14
0
def test_lif_heavi():
    """Two LIF steps with a strong input weight must produce spikes."""
    spikes_in = torch.ones(2, 1)
    state = LIFState(z=torch.ones(2, 1), v=torch.zeros(2, 1),
                     i=torch.zeros(2, 1))
    w_in = torch.ones(1, 1) * 10
    w_rec = torch.ones(1, 1)
    params = LIFParameters(method="heaviside")
    # First step charges the membrane; the second should cross threshold.
    _, state = lif_step(spikes_in, state, w_in, w_rec, params)
    z, state = lif_step(spikes_in, state, w_in, w_rec, params)
    assert z.max() > 0
    assert z.shape == (2, 1)
Example #15
0
    def __init__(self,
                 num_channels=1,
                 feature_size=28,
                 method="super",
                 dtype=torch.float):
        """Four-layer convolutional SNN with a lowered firing threshold on
        the convolutional LIF layers.

        Args:
            num_channels: channels of the input image.
            feature_size: side length of the square input image.
            method: surrogate-gradient method forwarded to LIFParameters.
            dtype: tensor dtype stored for later use.
        """
        super(ConvNet4, self).__init__()
        self.features = int(((feature_size - 4) / 2 - 4) / 2)

        self.conv1 = torch.nn.Conv2d(num_channels, 32, 5, 1)
        self.conv2 = torch.nn.Conv2d(32, 64, 5, 1)
        self.fc1 = torch.nn.Linear(64 * self.features * self.features, 1024)
        # Conv layers fire at v_th=0.7; the dense layer keeps the default.
        self.lif0 = LIFCell(
            p=LIFParameters(method=method, alpha=100.0,
                            v_th=torch.as_tensor(0.7)))
        self.lif1 = LIFCell(
            p=LIFParameters(method=method, alpha=100.0,
                            v_th=torch.as_tensor(0.7)))
        self.lif2 = LIFCell(p=LIFParameters(method=method, alpha=100.0))
        self.out = LILinearCell(1024, 10)
        self.dtype = dtype
Example #16
0
    def __init__(
        self, device, num_channels=1, feature_size=32, method="super", dtype=torch.float
    ):
        """LeNet-style spiking convolutional network (three convs, one FC).

        Args:
            device: device the caller intends to run the network on.
            num_channels: channels of the input image.
            feature_size: side length of the square input image.
            method: surrogate-gradient method forwarded to LIFParameters.
            dtype: tensor dtype stored for later use.
        """
        super(ConvvNet4, self).__init__()
        self.features = int(((feature_size - 4) / 2 - 4) / 2)

        # BUG FIX: conv1 previously hard-coded 1 input channel, silently
        # ignoring the ``num_channels`` argument.
        self.conv1 = torch.nn.Conv2d(num_channels, 6, kernel_size=5, stride=1)
        self.conv2 = torch.nn.Conv2d(6, 16, kernel_size=5, stride=1)
        self.conv3 = torch.nn.Conv2d(16, 120, kernel_size=5, stride=1)
        self.fc1 = torch.nn.Linear(120, 84)

        self.lif0 = LIFFeedForwardCell(p=LIFParameters(method=method, alpha=100.0))
        self.lif1 = LIFFeedForwardCell(p=LIFParameters(method=method, alpha=100.0))
        self.lif2 = LIFFeedForwardCell(p=LIFParameters(method=method, alpha=100.0))
        self.lif3 = LIFFeedForwardCell(p=LIFParameters(method=method, alpha=100.0))
        self.out = LICell(84, 10)

        self.device = device
        self.dtype = dtype
Example #17
0
    def __init__(
        self,
        num_channels=1,
        feature_size=32,
        model="super",
        dtype=torch.float,
    ):
        """Two-conv spiking network with feed-forward LIF cells.

        Args:
            num_channels: channels of the input image.
            feature_size: side length of the square input image.
            model: surrogate-gradient method forwarded to LIFParameters.
            dtype: tensor dtype stored for later use.
        """
        super(Net, self).__init__()
        self.features = int(((feature_size - 4) / 2 - 4) / 2)

        self.conv1 = torch.nn.Conv2d(num_channels, 32, 5, 1)
        self.conv2 = torch.nn.Conv2d(32, 64, 5, 1)
        self.fc1 = torch.nn.Linear(64 * self.features * self.features, 1024)
        self.lif0 = LIFFeedForwardCell(
            p=LIFParameters(method=model, alpha=100.0))
        self.lif1 = LIFFeedForwardCell(
            p=LIFParameters(method=model, alpha=100.0))
        self.lif2 = LIFFeedForwardCell(
            p=LIFParameters(method=model, alpha=100.0))
        self.out = LICell(1024, 10)
        self.dtype = dtype
Example #18
0
def test_lift_with_state_and_parameters():
    """lift() threads an explicitly supplied initial state through time."""
    inputs = torch.ones(3, 2, 1)
    stepper = lift(lif_feed_forward_step,
                   p=LIFParameters(v_th=torch.as_tensor(0.3), method="tanh"))
    initial = LIFFeedForwardState(torch.zeros_like(inputs[0]),
                                  torch.zeros_like(inputs[0]))
    z, state = stepper(inputs, state=initial)
    assert z.shape == (3, 2, 1)
    assert state.v.shape == (2, 1)
    assert state.i.shape == (2, 1)
Example #19
0
 def __init__(self,
              input_size: int,
              hidden_size: int,
              p: LIFParameters = LIFParameters(),
              **kwargs):
     """Recurrent LIF layer.

     Args:
         input_size (int): number of input features.
         hidden_size (int): number of hidden neurons.
         p (LIFParameters): neuron parameters.
         **kwargs: forwarded to the parent constructor.
     """
     super().__init__(
         activation=lif_step,
         state_fallback=self.initial_state,
         p=p,
         input_size=input_size,
         hidden_size=hidden_size,
         **kwargs,
     )
Example #20
0
 def __init__(self,
              input_size: int,
              hidden_size: int,
              p: LIFParameters = LIFParameters(),
              g_coupling: Optional[torch.Tensor] = None,
              **kwargs):
     """Multi-compartment LIF layer with learnable compartment coupling.

     Args:
         input_size (int): number of input features.
         hidden_size (int): number of hidden neurons/compartments.
         p (LIFParameters): neuron parameters.
         g_coupling (Optional[torch.Tensor]): coupling conductances; a
             randomly initialized trainable matrix is created when omitted.
         **kwargs: forwarded to the parent constructor.
     """
     super().__init__(activation=None,
                      state_fallback=self.initial_state,
                      input_size=input_size,
                      hidden_size=hidden_size,
                      p=p,
                      **kwargs)
     if g_coupling is None:
         # Scaled random init, matching the 1/sqrt(n) fan-in convention.
         g_coupling = torch.nn.Parameter(
             torch.randn(hidden_size, hidden_size) / np.sqrt(hidden_size))
     self.g_coupling = g_coupling
Example #21
0
    def __init__(self):
        """Spiking policy network: constant-current encoder, one LIF layer
        with dropout, and a leaky-integrator readout over two actions."""
        super(Policy, self).__init__()
        self.state_dim = 4
        self.input_features = 16
        self.hidden_features = 128
        self.output_features = 2
        # Encode observations as constant input currents over 40 time steps.
        self.constant_current_encoder = ConstantCurrentLIFEncoder(40)
        self.lif = LIFCell(
            2 * self.state_dim,
            self.hidden_features,
            p=LIFParameters(method="super", alpha=100.0),
        )
        self.dropout = torch.nn.Dropout(p=0.5)
        self.readout = LICell(self.hidden_features, self.output_features)

        # Rollout buffers, presumably consumed by a policy-gradient update
        # elsewhere — not visible in this chunk.
        self.saved_log_probs = []
        self.rewards = []
Example #22
0
def lif_mc_feed_forward_step(
    input_tensor: torch.Tensor,
    state: LIFFeedForwardState,
    g_coupling: torch.Tensor,
    p: LIFParameters = LIFParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFFeedForwardState]:
    """Computes a single euler-integration feed forward step of a LIF
    multi-compartment neuron-model.

    The compartment voltages are first coupled through ``g_coupling`` with
    one explicit Euler step, then a regular feed-forward LIF step is taken.

    Parameters:
        input_tensor (torch.Tensor): the (weighted) input spikes at the
                              current time step
        state (LIFFeedForwardState): current state of the neuron
        g_coupling (torch.Tensor): conductances between the neuron compartments
        p (LIFParameters): neuron parameters
        dt (float): Integration timestep to use

    Returns:
        Tuple[torch.Tensor, LIFFeedForwardState]: output spikes and the
        updated state.
    """
    # Couple the compartment voltages before the standard LIF integration.
    v_new = state.v + dt * torch.nn.functional.linear(state.v, g_coupling)
    return lif_feed_forward_step(input_tensor,
                                 LIFFeedForwardState(v_new, state.i), p, dt)
Example #23
0
File: test_snn.py — Project: norse/norse
def test_snn_recurrent_weights_autapse_update():
    """Autapses (diagonal recurrent weights) stay zero after one update."""
    input_weights = torch.ones(3, 2)
    recurrent_weights = torch.nn.Parameter(torch.ones(3, 3))
    net = snn.SNNRecurrent(
        lif_step,
        lambda x: LIFState(v=torch.zeros(3), i=torch.zeros(3),
                           z=torch.ones(3)),
        2,
        3,
        p=LIFParameters(v_th=torch.as_tensor(0.1)),
        input_weights=input_weights,
        recurrent_weights=recurrent_weights,
    )
    # The diagonal is zeroed at construction time.
    assert torch.all(torch.eq(net.recurrent_weights.diag(), torch.zeros(3)))
    optimizer = torch.optim.Adam(net.parameters())
    optimizer.zero_grad()
    z, state = net(torch.ones(1, 2))
    z, _ = net(torch.ones(1, 2), state)
    z.sum().backward()
    optimizer.step()
    weights = net.recurrent_weights.clone().detach()
    assert torch.all(torch.eq(weights.diag(), torch.zeros(3)))
Example #24
0
File: cifar10.py — Project: norse/norse
def main(args):
    """Train an LIFConvNet on CIFAR-10 using PyTorch Lightning.

    Args:
        args: parsed CLI namespace providing batch_size, seq_length, lr,
            optimizer and the PyTorch Lightning trainer arguments.
    """
    # Setup encoding
    num_channels = 3

    # Load datasets
    transform_train = torchvision.transforms.Compose(
        [
            torchvision.transforms.RandomCrop(32, padding=4),
            torchvision.transforms.RandomHorizontalFlip(),
            # BUG FIX: without ToTensor the loader yields PIL images, which
            # the default collate function cannot batch (the test transform
            # already converts to tensors).
            torchvision.transforms.ToTensor(),
        ]
    )
    transform_test = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
    train_loader = torch.utils.data.DataLoader(
        torchvision.datasets.CIFAR10(
            root=".", train=True, download=True, transform=transform_train
        ),
        batch_size=args.batch_size,
        num_workers=32,
        shuffle=True,
    )
    val_loader = torch.utils.data.DataLoader(
        torchvision.datasets.CIFAR10(root=".", train=False, transform=transform_test),
        batch_size=args.batch_size,
        num_workers=32,
    )

    # Define and train the model
    model = LIFConvNet(
        seq_length=args.seq_length,
        num_channels=num_channels,
        lr=args.lr,
        optimizer=args.optimizer,
        p=LIFParameters(v_th=torch.as_tensor(0.4)),
    )
    trainer = pl.Trainer.from_argparse_args(args)
    trainer.fit(model, train_loader, val_loader)
Example #25
0
 def forward(self, x):
     """Encode ``x`` as constant-current LIF spikes over the configured
     sequence length, using this module's membrane parameters."""
     neuron_parameters = LIFParameters(tau_mem_inv=self.tau_mem_inv,
                                       v_th=self.v_th,
                                       v_reset=self.v_reset)
     return encode.constant_current_lif_encode(
         x, self.seq_length, p=neuron_parameters, dt=self.dt)