def forward(self,xs):
    """Perform forward propagation of activations.

    Runs the recurrence over the `n = len(xs)` time steps of `xs` and
    fills the per-step buffers on `self` (`source`, `gix`/`gfx`/`gox`/`cix`,
    `gi`/`gf`/`go`/`ci`, `state`, `output`). `self.last_n` is set to `n`,
    which `backward` checks against the length of its own argument.

    `xs` is indexed by time step; each `xs[t]` is expected to have `ni`
    entries, where `ni, ns, na = self.dims` (only `xs[0]` is checked, and
    only by an `assert`).

    Returns `self.output[:n]`, one output row per time step.

    Raises `ocrolib.RecognitionError` if `n` exceeds the number of
    preallocated steps, `len(self.gi)`.
    """
    ni,ns,na = self.dims
    # NOTE: indexes xs[0] before the length is known, so an empty `xs`
    # fails here with an IndexError rather than a RecognitionError.
    assert len(xs[0])==ni
    n = len(xs)
    if n>len(self.gi): raise ocrolib.RecognitionError("input too large for LSTM model")
    self.last_n = n
    # NOTE(review): presumably clears the per-step buffers for the first
    # n steps -- confirm against `reset`.
    self.reset(n)
    for t in range(n):
        # Recurrent input: the previous step's output, or zeros at t == 0.
        prev = zeros(ns) if t==0 else self.output[t-1]
        # Layout of one source row: [1, xs[t] (ni entries), prev output].
        # The leading constant 1 multiplies column 0 of each weight matrix.
        self.source[t,0] = 1
        self.source[t,1:1+ni] = xs[t]
        self.source[t,1+ni:] = prev
        # Pre-activations for the three gates and the cell input, written
        # directly into the preallocated rows via `out=`.
        dot(self.WGI,self.source[t],out=self.gix[t])
        dot(self.WGF,self.source[t],out=self.gfx[t])
        dot(self.WGO,self.source[t],out=self.gox[t])
        dot(self.WCI,self.source[t],out=self.cix[t])
        if t>0:
            # ATTENTION: peep weights are diagonal matrices
            # (hence the elementwise `*` instead of `dot`). These two
            # contributions use the previous state, so they are skipped
            # at t == 0.
            self.gix[t] += self.WIP*self.state[t-1]
            self.gfx[t] += self.WFP*self.state[t-1]
        self.gi[t] = ffunc(self.gix[t])
        self.gf[t] = ffunc(self.gfx[t])
        self.ci[t] = gfunc(self.cix[t])
        # New state: gated cell input, plus the gated previous state
        # when there is one.
        self.state[t] = self.ci[t]*self.gi[t]
        if t>0:
            self.state[t] += self.gf[t]*self.state[t-1]
        # The `go` pre-activation gets its `WOP` term from the *current*
        # state, so it can only be finished after `state[t]` is computed.
        self.gox[t] += self.WOP*self.state[t]
        self.go[t] = ffunc(self.gox[t])
        self.output[t] = hfunc(self.state[t]) * self.go[t]
    # Sanity check: no NaN may appear in the outputs that are returned.
    assert not isnan(self.output[:n]).any()
    return self.output[:n]
def backward(self,deltas):
    """Perform backward propagation of deltas.

    `deltas[t]` is the error arriving at the output of time step `t`.
    The number of steps must equal `self.last_n`, and the activation
    buffers on `self` (`go`, `gf`, `gi`, `ci`, `state`, `source`) are read
    as they currently stand, so this is meant to follow a `forward` call
    on a sequence of the same length.

    Fills the per-step error buffers (`outerr`, `stateerr`, `goerr`,
    `gferr`, `gierr`, `cierr`, `sourceerr`) and stores weight gradients in
    `DWIP`, `DWFP`, `DWOP`, `DWGI`, `DWGF`, `DWGO`, `DWCI`. The weights
    themselves (`W*`) are not modified here.

    Returns a list with one entry per time step: entries `1:1+ni` of that
    step's `sourceerr` row.

    Raises `ocrolib.RecognitionError` if there are more steps than
    preallocated rows (`len(self.gi)`); an `assert` fails if the length
    differs from `self.last_n`.
    """
    n = len(deltas)
    if n>len(self.gi): raise ocrolib.RecognitionError("input too large")
    assert n==self.last_n
    ni,ns,na = self.dims
    # Walk backwards in time: step t consumes errors already computed
    # for step t+1.
    for t in reversed(range(n)):
        # Output error = external delta plus the last `ns` entries of the
        # next step's source error (the recurrent feedback).
        self.outerr[t] = deltas[t]
        if t<n-1:
            self.outerr[t] += self.sourceerr[t+1][-ns:]
        # NOTE(review): `fprime`/`gprime` are called with None plus the
        # stored activation value -- presumably the derivative expressed in
        # terms of the function's output; confirm against their definitions.
        self.goerr[t] = fprime(None,self.go[t]) * hfunc(self.state[t]) * self.outerr[t]
        self.stateerr[t] = hprime(self.state[t]) * self.go[t] * self.outerr[t]
        # Contribution through the `WOP` term (elementwise weights).
        self.stateerr[t] += self.goerr[t]*self.WOP
        if t<n-1:
            # Contributions from step t+1: through `WFP` and `WIP`
            # (elementwise weights) and through the gated carry-over of the
            # state.
            self.stateerr[t] += self.gferr[t+1]*self.WFP
            self.stateerr[t] += self.gierr[t+1]*self.WIP
            self.stateerr[t] += self.stateerr[t+1]*self.gf[t+1]
        if t>0:
            # Needs state[t-1], so it is not written at t == 0; gferr[0]
            # keeps whatever value the buffer already holds.
            self.gferr[t] = fprime(None,self.gf[t])*self.stateerr[t]*self.state[t-1]
        self.gierr[t] = fprime(None,self.gi[t])*self.stateerr[t]*self.ci[t] # gfunc(self.cix[t])
        self.cierr[t] = gprime(None,self.ci[t])*self.stateerr[t]*self.gi[t]
        # Accumulate the error with respect to the source row; the first
        # term is written with `out=`, the remaining ones are added.
        dot(self.gierr[t],self.WGI,out=self.sourceerr[t])
        if t>0:
            self.sourceerr[t] += dot(self.gferr[t],self.WGF)
        self.sourceerr[t] += dot(self.goerr[t],self.WGO)
        self.sourceerr[t] += dot(self.cierr[t],self.WCI)
    # Weight gradients over the whole sequence.
    # NOTE(review): `nutils.sumprod` / `nutils.sumouter` presumably sum
    # elementwise / outer products over the time axis -- confirm.
    # `DWIP` and `DWFP` pair the error at step t with the state at t-1, so
    # they cover steps 1..n-1 only; `DWGF` likewise skips step 0, where
    # `gferr` is not computed.
    self.DWIP = nutils.sumprod(self.gierr[1:n],self.state[:n-1],out=self.DWIP)
    self.DWFP = nutils.sumprod(self.gferr[1:n],self.state[:n-1],out=self.DWFP)
    self.DWOP = nutils.sumprod(self.goerr[:n],self.state[:n],out=self.DWOP)
    self.DWGI = nutils.sumouter(self.gierr[:n],self.source[:n],out=self.DWGI)
    self.DWGF = nutils.sumouter(self.gferr[1:n],self.source[1:n],out=self.DWGF)
    self.DWGO = nutils.sumouter(self.goerr[:n],self.source[:n],out=self.DWGO)
    self.DWCI = nutils.sumouter(self.cierr[:n],self.source[:n],out=self.DWCI)
    # Only entries 1:1+ni of each source-error row are handed back.
    return [s[1:1+ni] for s in self.sourceerr[:n]]
def backward(self, deltas):
    """Propagate output errors back through the layer.

    Call this only after `forward`, since it works on the buffers held
    on `self`.  No weight update happens here; the generic `update`
    method is responsible for that.

    `deltas` holds one error vector per time step.  The return value is
    a list with the corresponding deltas for the input vectors, one per
    time step.

    Raises `ocrolib.RecognitionError` when `deltas` has more steps than
    the preallocated buffers can hold.
    """
    in_dim, state_dim, a_dim = self.dims  # (ni, ns, na)
    steps = len(deltas)
    self.last_n = steps
    capacity = len(self.gi)
    if steps > capacity:
        raise ocrolib.RecognitionError("input too large for LSTM model")

    # Arguments are grouped by role and joined in exactly the positional
    # order `backward_py` is called with.
    sizes = (steps, capacity, in_dim, state_dim, a_dim)
    activations = (
        self.source,
        self.gix, self.gfx, self.gox, self.cix,
        self.gi, self.gf, self.go, self.ci,
        self.state, self.output,
    )
    weights = (
        self.WGI, self.WGF, self.WGO, self.WCI,
        self.WIP, self.WFP, self.WOP,
    )
    errors = (
        self.sourceerr,
        self.gierr, self.gferr, self.goerr, self.cierr,
        self.stateerr, self.outerr,
    )
    gradients = (
        self.DWGI, self.DWGF, self.DWGO, self.DWCI,
        self.DWIP, self.DWFP, self.DWOP,
    )
    backward_py(*(sizes + (deltas,) + activations + weights + errors + gradients))

    # Entries 1 .. ni of every source-error row are what gets returned.
    first, last = 1, 1 + in_dim
    return [row[first:last] for row in self.sourceerr[:steps]]
def forward(self, xs):
    """Run the layer forward over a whole input sequence.

    `xs` is a 2D array for sequence classification: every row is the
    input vector of one time step.  The internal buffers on `self` are
    updated so that `backward` can be called afterwards.

    Returns a 2D array with one output row per input row.

    Raises `ocrolib.RecognitionError` when `xs` has more rows than the
    preallocated buffers can hold.
    """
    in_dim, state_dim, a_dim = self.dims  # (ni, ns, na)
    assert len(xs[0]) == in_dim
    steps = len(xs)
    self.last_n = steps
    capacity = len(self.gi)
    if steps > capacity:
        raise ocrolib.RecognitionError("input too large for LSTM model")
    self.reset(steps)

    # Arguments are grouped by role and joined in exactly the positional
    # order `forward_py` is called with.
    sizes = (steps, capacity, in_dim, state_dim, a_dim)
    activations = (
        self.source,
        self.gix, self.gfx, self.gox, self.cix,
        self.gi, self.gf, self.go, self.ci,
        self.state, self.output,
    )
    weights = (
        self.WGI, self.WGF, self.WGO, self.WCI,
        self.WIP, self.WFP, self.WOP,
    )
    forward_py(*(sizes + (xs,) + activations + weights))

    result = self.output[:steps]
    # The outputs handed back must not contain NaN.
    assert not isnan(result).any()
    return result