def forward(ctx, spike, filter, Ts):
    device = spike.device
    dtype = spike.dtype
    # filter the spike train with the PSP kernel, scaled by the sampling time
    psp = slayerCuda.conv(spike.contiguous(), filter, Ts)
    # store Ts as a tensor so it can go through ctx.save_for_backward
    # (torch.autograd.Variable is deprecated; a plain tensor suffices)
    Ts = torch.tensor(Ts, device=device, dtype=dtype, requires_grad=False)
    ctx.save_for_backward(filter, Ts)
    return psp
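# When the slayerCuda extension is unavailable, a pure-PyTorch reference can
# stand in for slayerCuda.conv. This is a sketch under an assumption about its
# semantics: a causal convolution of the signal with a 1-D filter along the
# last (time) dimension, with the result scaled by the sampling time Ts.
# 'conv_reference' is a hypothetical name, not part of slayerCuda.
import torch
import torch.nn.functional as F

def conv_reference(x, filt, Ts):
    T = x.shape[-1]
    flat = x.reshape(-1, 1, T)                 # fold all leading dims into batch
    flat = F.pad(flat, (filt.numel() - 1, 0))  # left-pad => output[t] sees input[<=t] only
    kernel = filt.flip(0).reshape(1, 1, -1)    # conv1d correlates, so flip the filter
    out = F.conv1d(flat, kernel)
    return out.reshape(x.shape) * Ts           # integrate with the time step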
def backward(ctx, gradOutput):
    (output, delay, Ts) = ctx.saved_tensors
    # finite-difference filter, -1 * [1, -1], scaled by 1/Ts
    diffFilter = torch.tensor([-1, 1], dtype=gradOutput.dtype).to(gradOutput.device) / Ts
    # the conv operation should not be scaled by Ts (hence the third argument of 1).
    # As such, the output is -( x[k+1]/Ts - x[k]/Ts ), which is what we want.
    outputDiff = slayerCuda.conv(output, diffFilter, 1)
    # no minus needed here, as it is included in diffFilter, which is -1 * [1, -1]
    gradDelay = torch.sum(gradOutput * outputDiff, [0, -1], keepdim=True).reshape(gradOutput.shape[1:-1]) * Ts
    return slayerCuda.shift(gradOutput, -delay, Ts), gradDelay, None
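# A quick sanity check for the analytic delay gradient above is a central
# finite-difference probe. This is a sketch: 'loss_fn' is a hypothetical
# callable that applies the given delay and returns a scalar loss; the
# gradient returned by backward() should approach this estimate as eps -> 0,
# up to the discretization error of the [-1, 1]/Ts filter.
import torch

def finite_diff_delay_grad(loss_fn, delay, eps=1e-3):
    with torch.no_grad():
        lossPlus  = loss_fn(delay + eps)   # perturb the delay upward
        lossMinus = loss_fn(delay - eps)   # and downward
    return (lossPlus - lossMinus) / (2 * eps)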
def test(self, epoch=0, evalLoss=True, slidingWindow=None, breakIter=None):
    '''
    Testing assistant function.

    Arguments:
        * ``epoch``: training epoch number.
        * ``evalLoss``: a flag to enable or disable loss evaluation. Default: ``True``.
        * ``slidingWindow``: the length of the sliding window to use for continuous output prediction over time.
          ``None`` means the total spike count is used to produce one output per sample.
          If it is not ``None``, ``evalLoss`` is overridden to ``False``. Default: ``None``.
        * ``breakIter``: number of batches to process before breaking out of the testing loop.
          ``None`` means go over the complete testing set. Default: ``None``.
    '''
    if slidingWindow is not None:
        filter = torch.ones(slidingWindow).to(self.device)
        evalLoss = False

    tSt = datetime.now()

    for i, (input, target, label) in enumerate(self.testLoader, 0):
        self.net.eval()

        with torch.no_grad():
            input = input.to(self.device)
            target = target.to(self.device)

            count = 0
            if self.module.countLog is True:
                output, count = self.net.forward(input)
            else:
                output = self.net.forward(input)

            if slidingWindow is None:
                # one prediction per sample from the total spike count
                if self.stats is not None:
                    self.stats.testing.correctSamples += torch.sum(predict.getClass(output) == label).data.item()
                    self.stats.testing.numSamples += len(label)
            else:
                # continuous prediction from a sliding spike-count window
                filteredOutput = slayerCuda.conv(output.contiguous(), filter, 1)[..., slidingWindow:]
                predictions = torch.argmax(filteredOutput.reshape(-1, filteredOutput.shape[-1]), dim=0)

                if self.stats is not None:
                    self.stats.testing.correctSamples += torch.sum(predictions == label.to(self.device)).item()
                    self.stats.testing.numSamples += predictions.shape[0]

            if evalLoss is True:
                loss = self.error(output, target, label)
                if self.stats is not None:
                    self.stats.testing.lossSum += loss.cpu().data.item() * (1 if self.lossScale is None else self.lossScale)
            else:
                if self.stats is not None:
                    if slidingWindow is None:
                        self.stats.testing.lossSum += (1 if self.lossScale is None else self.lossScale)
                    else:
                        self.stats.testing.lossSum += predictions.shape[0] * (1 if self.lossScale is None else self.lossScale)

        if self.stats is not None and epoch % self.printInterval == 0:
            headerList = ['[{}/{} ({:.0f}%)]'.format(i*len(input), len(self.testLoader.dataset), 100.0*i/len(self.testLoader))]
            if self.module.countLog is True:
                headerList.append('Spike count: ' + ', '.join(['{}'.format(int(c)) for c in torch.sum(count, dim=0).tolist()]))
            if self.showTimeSteps is True:
                headerList.append('nTimeBins: {}'.format(input.shape[-1]))
            self.stats.print(
                epoch, i,
                (datetime.now() - tSt).total_seconds() / (i+1) / input.shape[0],
                header=headerList,
            )

        if breakIter is not None and i >= breakIter:
            break
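# A minimal usage sketch, assuming 'assistant' is an instance of the class
# this method belongs to, already wired with net, testLoader, error, and
# stats (names taken from the method body; the loop bound is hypothetical).
for epoch in range(100):
    assistant.test(epoch)                   # one output per sample (spike count)

assistant.test(epoch, slidingWindow=100)    # continuous prediction over time;
                                            # evalLoss is forced to False here
assistant.test(epoch, breakIter=10)         # quick smoke test: stop after 10 batches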
uOut = fc1(gradLog.apply(pspDelayed))
spikeOut = slayer.spike(gradLog.apply(uOut))

# loss
error = snn.loss(netParams).to(device)
loss = error.spikeTime(spikeOut, spikeDes)
loss.backward()

# custom calculation of the delay gradient
deltaRec = gradLog.data[0]
errorRec = gradLog.data[1]

# filter to differentiate the signal in the time dimension
diffFilter = torch.tensor([1, -1], dtype=torch.float).to(device) / Ts

# psp derivative signal
dpspDelayed_dt = slayerCuda.conv(pspDelayed, diffFilter, 1)

# delay gradient integration (according to the formula)
delayGrad = -torch.sum(errorRec * dpspDelayed_dt, [0, -1], keepdim=True).reshape((Nin, 1, 1)) * Ts

class TestAutoGrad1(unittest.TestCase):
    def test(self):
        self.assertTrue(
            torch.norm(delayGrad - delay.delay.grad).item() < 1e-4,
            'CustomDelayGradient and AutoGrad1 results must match.'
        )

# AutoGrad 2: first delay, followed by the psp operation
# reset the previous gradient
delay.delay.grad = None
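# The 'gradLog' helper used above is assumed to be an identity autograd op
# that records the gradients flowing through it. A minimal sketch under that
# assumption (the actual helper in this test harness may differ):
import torch

class GradLog(torch.autograd.Function):
    data = []  # gradients, in the order backward() reaches them

    @staticmethod
    def forward(ctx, x):
        return x  # identity in the forward pass

    @staticmethod
    def backward(ctx, gradOutput):
        GradLog.data.append(gradOutput.clone())  # log the incoming gradient
        return gradOutput

# Since backward() starts at the loss, data[0] is the gradient at uOut
# (deltaRec) and data[1] the gradient at pspDelayed (errorRec), matching
# the indexing in the script above.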