Example No. 1
    def forward(ctx, spike, filter, Ts):
        device = spike.device
        dtype  = spike.dtype
        # convolve the spike train with the response kernel to obtain the PSP signal
        psp = slayerCuda.conv(spike.contiguous(), filter, Ts)
        # store Ts as a non-differentiable tensor so it can be saved for the backward pass
        Ts = torch.autograd.Variable(torch.tensor(Ts, device=device, dtype=dtype), requires_grad=False)
        ctx.save_for_backward(filter, Ts)
        return psp
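For context, a forward with this (ctx, ...) signature is normally the static forward of a torch.autograd.Function subclass and is invoked through .apply; the wrapper and tensor names in the sketch below are illustrative assumptions, not part of the original code.

# assuming the method above sits inside a torch.autograd.Function subclass, e.g.
#
#   class _pspFunction(torch.autograd.Function):   # illustrative name
#       @staticmethod
#       def forward(ctx, spike, filter, Ts): ...
#
# it would be applied to a spike tensor as:
#   psp = _pspFunction.apply(spike, srmKernel, Ts)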
Example No. 2
    def backward(ctx, gradOutput):
        (output, delay, Ts) = ctx.saved_tensors
        diffFilter = torch.tensor([-1, 1], dtype=gradOutput.dtype).to(gradOutput.device) / Ts
        outputDiff = slayerCuda.conv(output, diffFilter, 1)
        # the conv operation should not be scaled by Ts. 
        # As such, the output is -( x[k+1]/Ts - x[k]/Ts ) which is what we want.
        gradDelay  = torch.sum(gradOutput * outputDiff, [0, -1], keepdim=True).reshape(gradOutput.shape[1:-1]) * Ts
        # no minus needed here, as it is included in diffFilter which is -1 * [1, -1]

        return slayerCuda.shift(gradOutput, -delay, Ts), gradDelay, None
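For reference, a minimal sketch of the same delay-gradient computation written with explicit tensor slicing instead of slayerCuda.conv. The function name delay_grad_reference is illustrative, the (N, C, H, W, T) layout is assumed from the shapes used above, and the boundary handling at the last time step is a simplification that may differ from the CUDA kernel:

import torch

def delay_grad_reference(gradOutput, output, Ts):
    # finite-difference derivative with the sign folded in, matching the comment above:
    # outputDiff[k] = -( output[k+1] - output[k] ) / Ts
    outputDiff = torch.zeros_like(output)
    outputDiff[..., :-1] = -(output[..., 1:] - output[..., :-1]) / Ts
    # sum over the batch and time dimensions; the trailing * Ts matches the original code
    gradDelay = torch.sum(gradOutput * outputDiff, [0, -1], keepdim=True)
    return gradDelay.reshape(gradOutput.shape[1:-1]) * Ts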
Example No. 3
    def test(self, epoch=0, evalLoss=True, slidingWindow=None, breakIter=None):
        '''
        Testing assistant function.

        Arguments:
            * ``epoch``: training epoch number.
            * ``evalLoss``: a flag to enable or disable loss evaluation. Default: ``True``.
            * ``slidingWindow``: the length of the sliding window to use for continuous output prediction over time.
                ``None`` means the total spike count is used to produce one output per sample. If it is not
                ``None``, ``evalLoss`` is overwritten to ``False``. Default: ``None``.
            * ``breakIter``: number of mini-batch iterations after which to break out of the testing loop.
                ``None`` means go over the complete set of testing samples. Default: ``None``.

        A short usage sketch follows this function.
        '''
        if slidingWindow is not None:
            filter = torch.ones((slidingWindow)).to(self.device)
            evalLoss = False

        tSt = datetime.now()
        for i, (input, target, label) in enumerate(self.testLoader, 0):
            self.net.eval()

            with torch.no_grad():
                input  = input.to(self.device)
                target = target.to(self.device) 

                count = 0
                if self.module.countLog is True:
                    output, count = self.net.forward(input)
                else:
                    output = self.net.forward(input)

                if slidingWindow is None:
                    if self.stats is not None:
                        self.stats.testing.correctSamples += torch.sum( predict.getClass(output) == label ).data.item()
                        self.stats.testing.numSamples     += len(label)
                else:
                    filteredOutput = slayerCuda.conv(output.contiguous(), filter, 1)[..., slidingWindow:]
                    predictions = torch.argmax(filteredOutput.reshape(-1, filteredOutput.shape[-1]), dim=0)
                    
                    if self.stats is not None:
                        self.stats.testing.correctSamples += torch.sum(predictions == label.to(self.device)).item()
                        self.stats.testing.numSamples     += predictions.shape[0]

                if evalLoss is True:
                    loss = self.error(output, target, label)
                    if self.stats is not None:
                        self.stats.testing.lossSum += loss.cpu().data.item() * (1 if self.lossScale is None else self.lossScale)
                else:
                    if self.stats is not None:
                        if slidingWindow is None:
                            self.stats.testing.lossSum += (1 if self.lossScale is None else self.lossScale)
                        else:
                            self.stats.testing.lossSum += predictions.shape[0] * (1 if self.lossScale is None else self.lossScale)

            if self.stats is not None and epoch%self.printInterval == 0:
                headerList = ['[{}/{} ({:.0f}%)]'.format(i*len(input), len(self.testLoader.dataset), 100.0*i/len(self.testLoader))]
                if self.module.countLog is True:
                    headerList.append('Spike count: ' + ', '.join(['{}'.format(int(c)) for c in torch.sum(count, dim=0).tolist()]))
                if self.showTimeSteps is True:
                    headerList.append('nTimeBins: {}'.format(input.shape[-1]))

                self.stats.print(
                    epoch, i, 
                    (datetime.now() - tSt).total_seconds() / (i+1) / input.shape[0],
                    header= headerList,
                )

            if breakIter is not None and i >= breakIter:
                break
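A brief usage sketch of the two modes described in the docstring; here ``assistant`` stands for an already-constructed instance of the class that owns ``test()`` (with net, testLoader, error, and stats configured elsewhere), which is an assumption for illustration:

# default mode: one prediction per sample from the total spike count, loss is evaluated
assistant.test(epoch=0)

# sliding-window mode: continuous prediction over time with a 100-step window;
# evalLoss is forced to False, and the loop breaks after 50 mini-batches
assistant.test(epoch=0, slidingWindow=100, breakIter=50)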
Example No. 4
uOut = fc1(gradLog.apply(pspDelayed))
spikeOut = slayer.spike(gradLog.apply(uOut))

# loss
error = snn.loss(netParams).to(device)
loss  = error.spikeTime(spikeOut, spikeDes)

loss.backward()

# Custom calculation of delay gradient
deltaRec = gradLog.data[0]
errorRec = gradLog.data[1]
# filter to differentiate the signal along the time dimension
diffFilter = torch.tensor([1, -1], dtype=torch.float).to(device)/Ts
# psp derivative signal
dpspDelayed_dt = slayerCuda.conv(pspDelayed, diffFilter, 1)
# delay gradient integration (according to the formula)
delayGrad = -torch.sum(errorRec * dpspDelayed_dt, [0, -1], keepdim=True).reshape((Nin, 1, 1)) * Ts
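# in other words, for each input neuron i:
#     dL/d(delay_i) = -sum over batch and time of errorRec_i(t) * d(pspDelayed_i)/dt * Ts
# i.e. the recorded error signal correlated with the time derivative of the delayed PSP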

class TestAutoGrad1(unittest.TestCase):
	def test(self):
		# print('CustomDelayGradient - autoGrad1:', torch.norm(delayGrad - delay.delay.grad).item())
		# self.assertEqual(torch.norm(delayGrad - delay.delay.grad).item(), 0, 'CustomDelayGradient and AutoGrad1 results must match.')
		self.assertTrue(torch.norm(delayGrad - delay.delay.grad).item() < 1e-4, 'CustomDelayGradient and AutoGrad1 results must match.')

# AutoGrad 2:
# first delay followed by PSP operation

# reset previous gradient
delay.delay.grad = None