예제 #1
0
    def _generate_sequence(self, idx, rep_idx):
        """Generate one store/recall sample and its per-step labels.

        Each of the ``self.steps`` steps carries a random one-hot 2-class data
        row, plus a "store" flag channel and a "recall" flag channel. The flag
        is set only at ``self.store_indices[idx][rep_idx]`` and
        ``self.recall_indices[idx][rep_idx]`` respectively. At the recall step
        the data channels are blanked. At both the store and the recall step,
        the label is the data row that was present at the store step.

        Args:
            idx: First index into ``self.store_indices`` / ``self.recall_indices``.
            rep_idx: Second index into the same tables.

        Returns:
            Tuple ``(encoded, label_pattern)``, both moved to ``self.device``.

            ``encoded`` is the Poisson-encoded input, with the per-step blocks
            concatenated along the first axis. NOTE(review): this assumes
            ``poisson_encode`` returns ``(seq_length, *input.shape)``. In that
            case the shape is ``(steps * seq_length, 4 * population_size)``;
            confirm.

            ``label_pattern`` is a ``(steps, 2)`` uint8 tensor. It is one-hot at
            the store and recall steps and zero elsewhere.
        """
        # One random one-hot row per step: randperm(2) is [0, 1] or [1, 0].
        data_pattern = torch.stack(
            [torch.randperm(2, generator=self.generator) for _ in range(self.steps)]
        ).byte()
        store_index = self.store_indices[idx][rep_idx]
        recall_index = self.recall_indices[idx][rep_idx]

        store_pattern = torch.zeros((self.steps, 1)).byte()
        recall_pattern = store_pattern.clone()
        label_pattern = torch.zeros((self.steps, 2)).byte()

        store_pattern[store_index] = 1
        recall_pattern[recall_index] = 1

        # Clone so the label is an independent copy. Indexing with an integer
        # returns a view into data_pattern, which is modified just below.
        label_class = data_pattern[store_index].clone()
        label_pattern[store_index] = label_class
        label_pattern[recall_index] = label_class
        # Blank the data channels at the recall step.
        data_pattern[recall_index] = 0

        # Channel layout per step: [data (2), store (1), recall (1)].
        # Each channel is then repeated population_size times, consecutively.
        input_pattern = torch.cat((data_pattern, store_pattern, recall_pattern), dim=1)
        input_pattern = input_pattern.repeat_interleave(self.population_size, dim=1)
        encoded = poisson_encode(
            input_pattern,
            seq_length=self.seq_length,
            f_max=self.poisson_rate,
            dt=self.dt,
        )
        # Split along the step axis (dim 1) and concatenate along dim 0, so the
        # steps are laid out back to back in time. Squeeze only the singleton
        # step axis: a bare squeeze() would also drop the time axis when
        # steps * seq_length == 1.
        encoded = torch.cat(encoded.chunk(self.steps, dim=1)).squeeze(1)
        return encoded.to(self.device), label_pattern.to(self.device)
예제 #2
0
파일: memory.py 프로젝트: norse/norse
 def encode_pattern(pattern, hz):
     """Poisson-encode ``pattern`` after widening each column into a population.

     Every column of ``pattern`` is repeated ``self.population_size`` times
     along dim 1 before encoding. ``hz`` is passed through as ``f_max``.
     ``seq_length`` and ``dt`` are read from ``self``.
     """
     # NOTE(review): ``self`` is not a parameter here. It is presumably
     # captured from an enclosing method that lies outside this snippet.
     population = pattern.repeat_interleave(self.population_size, dim=1)
     return poisson_encode(population, seq_length=self.seq_length, f_max=hz, dt=self.dt)
예제 #3
0
 def forward(self, x):
     """Encode ``x`` with ``encode.poisson_encode``.

     Uses this object's ``seq_length``, ``f_max`` and ``dt`` attributes as the
     encoder settings. Presumably the result is a spike tensor with a leading
     time axis of length ``seq_length``; confirm against the encoder's docs.
     """
     encoder = encode.poisson_encode
     return encoder(x, self.seq_length, f_max=self.f_max, dt=self.dt)