def shape(self):
    """Shape property attached to the ``lstm`` function object below.

    NOTE(review): assumes ``self.limits`` is a sequence of single-element
    tuples laid out as (W,), (Wh,), (b,) and optionally (t,), mirroring the
    ``lstm[W, Wh, b, t]`` indexing used by ``LSTM``/``LSTMCell`` — confirm
    against ``Function``.

    Returns:
        ``(2, d)`` when a fourth limit (the time index) is present; otherwise
        ``(x_shape[-2], d)`` for the input ``self.arg``, with ``x_shape[0]``
        prepended when the input has more than two axes.
    """
    (Wh,) = self.limits[1]
    # The recurrent weight's last axis is 4 * d (presumably four stacked
    # gates); d is the per-step output width.
    # NOTE(review): true division — exact for sympy integers, but produces a
    # float if shape entries are plain Python ints; confirm.
    d = Wh.shape[-1] / 4
    if len(self.limits) > 3:
        # A time index is bound: the result is a pair of d-sized components.
        return (2, d)
    x_shape = self.arg.shape
    result = (x_shape[-2], d)
    if len(x_shape) > 2:
        # Prepend the size of the leading axis (x_shape[0]), not the
        # tensor slice x[0].
        result = (x_shape[0],) + result
    return result


lstm = Function.lstm(
    real=True,
    integer=None,
    eval=lstm_recursive,
    shape=property(shape),
)


def _lstm_component(x, weights, index):
    """Return ``Lamda[t:n](Indexed(lstm[W, Wh, b, t](x), index))``.

    ``weights`` is unpacked as three single-element tuples (W,), (Wh,), (b,),
    and ``n`` is ``x.shape[0]``. Shared by ``LSTM`` and ``LSTMCell``, which
    differ only in the component ``index`` they select.
    """
    (W,), (Wh,), (b,) = weights
    n = x.shape[0]
    t = Symbol.t(integer=True)
    return Lamda[t:n](Indexed(lstm[W, Wh, b, t](x), index))


def LSTM(x, *weights):
    """Select component 0 of ``lstm[W, Wh, b, t](x)`` over ``t:n``.

    ``n = x.shape[0]``; ``weights`` must be (W,), (Wh,), (b,).
    NOTE(review): component 0 is presumably the hidden state — confirm.
    """
    return _lstm_component(x, weights, 0)


def LSTMCell(x, *weights):
    """Select component 1 of ``lstm[W, Wh, b, t](x)`` over ``t:n``.

    ``n = x.shape[0]``; ``weights`` must be (W,), (Wh,), (b,).
    NOTE(review): component 1 is presumably the cell state — confirm.
    """
    return _lstm_component(x, weights, 1)