def do_get_init_info(self, initstates):
    """Build the two initial states (output and cell) for this block.

    Args:
        initstates: either a sequence whose first element is the initial
            cell state, or a single value used as the first dimension of
            newly created zero states (presumably the batch size -- TODO
            confirm against callers).

    Returns:
        A pair ``([y_t0, c_t0], rest)``. ``y_t0`` is always a zero tensor of
        shape ``(rows, self.innerdim)``. ``rest`` is ``initstates[1:]`` when
        a sequence was given, otherwise ``initstates`` itself unchanged.
    """
    if issequence(initstates):
        cell_init, remaining = initstates[0], initstates[1:]
        # Size the zero output state after the provided cell state.
        nrows = cell_init.shape[0]
    else:
        remaining = nrows = initstates
        cell_init = T.zeros((nrows, self.innerdim))
    out_init = T.zeros((nrows, self.innerdim))
    return [out_init, cell_init], remaining
def do_get_init_info(self, initstates):
    """Build the three initial states (m, memory, h) for this block.

    Args:
        initstates: either a sequence whose first two elements are the
            initial ``h`` state and the initial memory, or a single value
            used as the first dimension of newly created zero states (the
            original comment calls it the batch size scalar).

    Returns:
        A pair ``([m_t0, mem_t0, h_t0], rest)``. ``m_t0`` is always a zero
        tensor of shape ``(rows, self.innerdim)``. ``rest`` is
        ``initstates[2:]`` when a sequence was given, otherwise
        ``initstates`` itself unchanged.
    """
    if issequence(initstates):
        hidden_init, memory_init = initstates[0], initstates[1]
        leftover = initstates[2:]
        # Size the zero m-state after the provided hidden state.
        nrows = hidden_init.shape[0]
    else:
        leftover = nrows = initstates
        hidden_init = T.zeros((nrows, self.innerdim))
        memory_init = T.zeros((nrows, self.memsize, self.innerdim))
    m_init = T.zeros((nrows, self.innerdim))
    return [m_init, memory_init, hidden_init], leftover
def apply(self, steps):
    """Unroll ``self.block.rec`` for ``steps`` iterations on all-zero input.

    The wrapped block is initialised via ``get_init_info(1)`` and scanned
    over a zero tensor of shape ``(steps, 1, 1)``.

    Args:
        steps: number of scan iterations (first dimension of the zero input).

    Returns:
        ``scan_result[0][:, 0, :]`` -- the first element of what ``T.scan``
        returns, with its second axis indexed at 0.
        NOTE(review): the exact structure returned by ``T.scan`` is not
        visible here -- confirm that element 0 is a single 3-D tensor.
    """
    start_states = self.block.get_init_info(1)
    zero_inputs = T.zeros((steps, 1, 1))
    # The leading None marks the first output of ``rec`` as non-recurrent.
    scan_result = T.scan(
        self.block.rec,
        sequences=zero_inputs,
        outputs_info=[None] + start_states,
    )
    primary = scan_result[0]
    return primary[:, 0, :]
def _get_seq_emb_t0(self, num, startsymemb=None):
    """Return the first-timestep embedding repeated ``num`` times.

    Args:
        num: number of rows to produce (repeat count along axis 0).
        startsymemb: optional embedding vector for the start symbol. When
            omitted, a float32 zero vector of length
            ``self.embedder.outdim`` is used.

    Returns:
        The start vector with a new leading axis, repeated ``num`` times
        along that axis.
    """
    dim = self.embedder.outdim
    if startsymemb is None:
        start_vec = T.zeros((dim,), dtype="float32")
    else:
        start_vec = startsymemb
    # Add a leading axis so the vector can be tiled once per row.
    return T.repeat(start_vec[np.newaxis, :], num, axis=0)
def applymask(cls, xseq, maskseq=None):
    """Blend ``xseq`` with a fixed one-hot fallback according to ``maskseq``.

    Where the mask is 1 the value of ``xseq`` is kept. Where it is 0 the
    value is replaced by a vector that is 1 at index 0 of the last axis and
    0 elsewhere (per the original comment: gives 100% probability to
    output 0).

    Args:
        xseq: 3-D tensor; the original comments describe it as
            (batsize, seqlen, outdim).
        maskseq: optional mask over the first two axes of ``xseq``. When
            ``None`` the input is returned untouched.

    Returns:
        ``xseq`` itself if no mask is given, otherwise the blended tensor.
    """
    if maskseq is None:
        return xseq

    # Outer product with a ones-vector stacks the mask along the last axis.
    stacked_mask = T.tensordot(maskseq, T.ones((xseq.shape[2],)), 0)
    # One-hot fallback: ones in slot 0 of the last axis, zeros elsewhere.
    fallback = T.concatenate(
        [
            T.ones((xseq.shape[0], xseq.shape[1], 1)),
            T.zeros((xseq.shape[0], xseq.shape[1], xseq.shape[2] - 1)),
        ],
        axis=2,
    )
    kept = xseq * stacked_mask
    filled = fallback * (1.0 - stacked_mask)
    return kept + filled
def apply(self, data, weights):
    """Accumulate a weighted sum of ``data`` over its second axis via scan.

    Args:
        data: per the original comment, shaped (batsize, seqlen, elem_dim).
        weights: per-element weights; its transpose is scanned in step with
            ``data`` swapped to put axis 1 first.

    Returns:
        The accumulator after the last scan step, ``o[-1, :, :]`` (per the
        original comment, (batsize, elem_dim)).
    """

    def accumulate(x_t, att_t, acc):
        # Per the original comments: x_t is (batsize, elem_dim),
        # att_t is (batsize,), acc is (batsize, elem_dim).
        acc += T.batched_dot(x_t, att_t)
        return acc

    step_major_data = data.dimswap(1, 0)
    step_major_weights = weights.T
    zero_acc = T.zeros((data.shape[0], data.shape[2]))
    o, _ = T.scan(
        fn=accumulate,
        sequences=[step_major_data, step_major_weights],
        outputs_info=zero_acc,
    )
    # Only the final accumulator value is of interest.
    return o[-1, :, :]
def get_init_info(self, initstates):
    """Normalise ``initstates`` into a list of initial state tensors.

    Args:
        initstates: either a sequence of initial states or a single value
            (per the original comment, the batch size). A single value is
            replicated ``self.numstates`` times.

    Returns:
        A list with one entry per given state. Entries that are Python ints
        or have ``ndim == 0`` are replaced by zero tensors of shape
        ``(entry, self.innerdim)``; all other entries are passed through.
    """
    states = (
        initstates if issequence(initstates)
        else [initstates] * self.numstates
    )
    return [
        T.zeros((state, self.innerdim))
        if isinstance(state, int) or state.ndim == 0
        else state
        for state in states
    ]
def get_init_info(self, initstates):
    """Turn ``initstates`` into a list of concrete initial states.

    Args:
        initstates: a sequence of initial states, or one value (per the
            original comment, the batch size) that is repeated
            ``self.numstates`` times.

    Returns:
        A list in which every scalar-like entry (a Python int, or an object
        with ``ndim == 0``) has been expanded to a zero tensor of shape
        ``(entry, self.innerdim)``; other entries are kept as given.
    """
    if issequence(initstates):
        given = initstates
    else:
        given = [initstates] * self.numstates

    prepared = []
    for entry in given:
        # A scalar entry is a size, not a state, so build zeros from it.
        needs_zeros = isinstance(entry, int) or entry.ndim == 0
        prepared.append(
            T.zeros((entry, self.innerdim)) if needs_zeros else entry
        )
    return prepared
def apply(self, data, weights):
    """Compute a weighted sum of ``data`` along its second axis using scan.

    Args:
        data: per the original comment, shaped (batsize, seqlen, elem_dim).
        weights: weights scanned (transposed) alongside ``data`` with its
            first two axes swapped.

    Returns:
        The last row of the scan output, ``o[-1, :, :]`` (per the original
        comment, (batsize, elem_dim)).
    """

    def add_weighted(x_t, att_t, acc):
        # Original comments give the shapes as x_t: (batsize, elem_dim),
        # att_t: (batsize,), acc: (batsize, elem_dim).
        acc += T.batched_dot(x_t, att_t)
        return acc

    scanned_inputs = [data.dimswap(1, 0), weights.T]
    initial_sum = T.zeros((data.shape[0], data.shape[2]))
    o, _ = T.scan(
        fn=add_weighted,
        sequences=scanned_inputs,
        outputs_info=initial_sum,
    )
    # The running sum after the final step is the result.
    return o[-1, :, :]
def applymask(cls, xseq, maskseq):
    """Replace masked-out positions of ``xseq`` with a one-hot fallback.

    Positions where the mask is 1 keep their value from ``xseq``. Positions
    where it is 0 receive a vector with 1 at index 0 of the last axis and 0
    elsewhere (per the original comment: gives 100% probability to
    output 0).

    Args:
        xseq: 3-D tensor; the original comments describe it as
            (batsize, seqlen, outdim).
        maskseq: mask over the first two axes of ``xseq``, or ``None`` to
            leave ``xseq`` untouched.

    Returns:
        ``xseq`` unchanged when ``maskseq`` is ``None``, otherwise the
        masked combination.
    """
    if maskseq is not None:
        nbatch, nsteps, nout = xseq.shape[0], xseq.shape[1], xseq.shape[2]
        # Outer product with ones broadcasts the mask over the last axis.
        mask = T.tensordot(maskseq, T.ones((nout,)), 0)
        # Fallback distribution: all mass on index 0 of the last axis.
        masker = T.concatenate(
            [T.ones((nbatch, nsteps, 1)), T.zeros((nbatch, nsteps, nout - 1))],
            axis=2,
        )
        result = xseq * mask + masker * (1.0 - mask)
    else:
        result = xseq
    return result
def get_init_info(self, initstates):
    """Build the list of initial states, using learned parameters if any.

    Args:
        initstates: a sequence of initial states, or one value (per the
            original comment, the batch size) that is repeated
            ``self.numstates`` times.

    Returns:
        A list of initial states. Each scalar-like entry (a Python int, or
        an object with ``ndim == 0``) is expanded: if a matching entry of
        ``self.initstateparams`` is set, that parameter is given a leading
        axis and repeated ``entry`` times along it; otherwise a zero tensor
        of shape ``(entry, self.innerdim)`` is used. Non-scalar entries are
        passed through unchanged.

    Note:
        States and parameters are paired with ``zip``, so the result is as
        long as the shorter of the two.
    """
    if not issequence(initstates):
        initstates = [initstates] * self.numstates

    params = self.initstateparams
    if params is None:
        params = [None] * self.numstates

    states = []
    for state, param in zip(initstates, params):
        is_size = isinstance(state, int) or state.ndim == 0
        if not is_size:
            # Already a concrete state; use as given.
            states.append(state)
        elif param is None:
            states.append(T.zeros((state, self.innerdim)))
        else:
            # Tile the parameter once per row requested by ``state``.
            states.append(T.repeat(param.dimadd(0), state, axis=0))
    return states
def do_get_init_info(self, initstates):
    """Build the single initial state for this block.

    Args:
        initstates: either a sequence of initial states or one value (per
            the original comment, the batch size).

    Returns:
        A pair ``([state], rest)``. For a sequence, ``state`` is its first
        element and ``rest`` is ``initstates[1:]``. Otherwise ``state`` is a
        zero tensor of shape ``(initstates, self.innerdim)`` and ``rest`` is
        ``initstates`` itself.
    """
    if not issequence(initstates):
        return [T.zeros((initstates, self.innerdim))], initstates
    own_state, remaining = initstates[0], initstates[1:]
    return [own_state], remaining
def apply(self, x, mask=None):
    """Scan ``self.rec`` over ``x`` and ``mask`` with their first two axes swapped.

    Args:
        x: per the original comment, shaped (batsize, seqlen, dim).
        mask: per the original comment, shaped (batsize, seqlen). Defaults
            to an all-zero tensor over the first two axes of ``x``.

    Returns:
        None. NOTE(review): the result of ``T.scan`` is discarded, so this
        method currently produces no output -- confirm whether a return
        value is missing and whether an all-zero default mask is intended.
    """
    if mask is None:
        mask = T.zeros((x.shape[0], x.shape[1]))
    # Swap axes 0 and 1 so scan iterates over what was the second axis.
    step_major_inputs = [x.dimswap(1, 0), mask.dimswap(1, 0)]
    T.scan(fn=self.rec, sequences=step_major_inputs)