class SeqTrans(Block):
    """Sequence transducer: encodes an input sequence and emits the
    encoder output at every time step (all_outputs), with masking
    disabled (MaskMode.NONE)."""

    def __init__(self, embedder, *layers, **kw):
        """Build the wrapped SeqEncoder from *embedder* followed by *layers*.

        Extra keyword args are forwarded to the Block base class.
        """
        super(SeqTrans, self).__init__(**kw)
        encoder = SeqEncoder(embedder, *layers)
        # per-timestep outputs, no mask handling
        encoder.all_outputs().maskoption(MaskMode.NONE)
        self.enc = encoder

    def apply(self, x):
        """Run the encoder on *x* and return its per-step outputs."""
        return self.enc(x)
class Seq2Vec(Block):
    """Encodes a sequence into a single vector, optionally pooling over
    all encoder time steps when a *pool* block is supplied."""

    def __init__(self, inpemb, enclayers, maskid=0, pool=None, **kw):
        """
        :param inpemb: input embedder block
        :param enclayers: one encoder layer or a sequence of them
        :param maskid: token id treated as padding (default 0)
        :param pool: optional pooling block applied over all outputs
        """
        super(Seq2Vec, self).__init__(**kw)
        self.inpemb = inpemb
        self.maskid = maskid
        self.pool = pool
        layers = enclayers if issequence(enclayers) else [enclayers]
        self.enc = SeqEncoder(inpemb, *layers).maskoptions(maskid, MaskMode.AUTO)
        if self.pool is not None:
            # NOTE(review): bare attribute chain (no call) — presumably a
            # fluent/property-style API on SeqEncoder; confirm against its impl.
            self.enc = self.enc.all_outputs.with_mask

    def all_outputs(self):
        """Switch the encoder to return outputs for every time step. Fluent."""
        self.enc = self.enc.all_outputs()
        return self

    def apply(self, x, mask=None, weights=None):
        """Encode *x*; pool over time steps if a pool block was given."""
        if self.pool is None:
            return self.enc(x, mask=mask, weights=weights)
        seq, outmask = self.enc(x, mask=mask, weights=weights)
        return self.pool(seq, outmask)
class SimpleSeq2MultiVec(Block):
    """Encodes a sequence into *numouts* vectors via self-attention.

    The RNN's last layer is widened by `numouts` extra units; each extra
    unit's per-step activation is softmax-normalized into attention
    weights that pool the remaining units into one output vector.
    Outputs are concatenated (`mode="concat"`) or stacked as a length-
    `numouts` sequence (`mode="seq"`).
    """

    def __init__(self, indim=400, inpembdim=50, inpemb=None, mode="concat",
                 innerdim=100, numouts=1, maskid=0, bidir=False,
                 maskmode=MaskMode.NONE, **kw):
        """
        :param indim: vocabulary size (used only when building an embedder)
        :param inpembdim: embedding dim; None -> one-hot embedding
        :param inpemb: existing embedder; None builds one, False disables
        :param mode: "concat" or "seq" — how the output vectors are combined
        :param innerdim: RNN layer dim or list of dims (last is widened)
        :param numouts: number of attention heads / output vectors
        :param maskid: padding token id
        :param bidir: build a bidirectional RNN
        :param maskmode: mask handling mode for the encoder
        """
        super(SimpleSeq2MultiVec, self).__init__(**kw)
        if inpemb is None:
            if inpembdim is None:
                inpemb = IdxToOneHot(indim)
                inpembdim = indim
            else:
                inpemb = VectorEmbed(indim=indim, dim=inpembdim)
        elif inpemb is False:
            # explicitly no embedder: encoder receives raw input
            inpemb = None
        else:
            inpembdim = inpemb.outdim
        if not issequence(innerdim):
            innerdim = [innerdim]
        # widen the last layer: first `numouts` units become attention scores
        innerdim[-1] += numouts
        rnn, lastdim = self.makernu(inpembdim, innerdim, bidir=bidir)
        self.outdim = lastdim * numouts if mode == "concat" else lastdim
        self.maskid = maskid
        self.inpemb = inpemb
        self.numouts = numouts
        self.mode = mode
        self.bidir = bidir
        if not issequence(rnn):
            rnn = [rnn]
        self.enc = SeqEncoder(inpemb, *rnn).maskoptions(maskid, maskmode)
        self.enc.all_outputs()

    @staticmethod
    def makernu(inpembdim, innerdim, bidir=False):
        """Build the RNN stack; returns (layers, last_layer_dim)."""
        return MakeRNU.make(inpembdim, innerdim, bidir=bidir)

    def apply(self, x, mask=None, weights=None):
        """Encode *x* and return the pooled output vector(s).

        :returns: (batsize, lastdim*numouts) for mode "concat",
                  (batsize, numouts, lastdim) for mode "seq".
        """
        ret = self.enc(x, mask=mask, weights=weights)   # (batsize, seqlen, lastdim)
        outs = []
        # apply mask (SeqEncoder should attach mask to outvar if all_outputs()
        mask = mask if mask is not None else ret.mask if hasattr(
            ret, "mask") else None
        if self.bidir:
            # split forward/backward halves of the bidirectional output
            # NOTE(review): `/` is integer division only under Python 2 —
            # confirm intended interpreter, else this yields a float index
            mid = ret.shape[2] / 2
            ret1 = ret[:, :, :mid]
            ret2 = ret[:, :, mid:]
            ret = ret1      # forward half becomes the working tensor
        for i in range(self.numouts):
            # unit i of the widened layer supplies the attention scores
            selfweights = ret[:, :, i]      # (batsize, seqlen)
            if self.bidir:
                selfweights += ret2[:, :, i]
            selfweights = Softmax()(selfweights)
            if mask is not None:
                selfweights *= mask     # apply mask
            # renormalize so masked weights still sum to 1 over seqlen
            selfweights = selfweights / T.sum(selfweights, axis=1).dimshuffle(
                0, "x")  # renormalize
            # attention-weighted sum over the non-score units
            weightedstates = ret[:, :, self.numouts:] * selfweights.dimshuffle(
                0, 1, "x")
            if self.bidir:
                weightedstates2 = ret2[:, :, self.numouts:] * selfweights.dimshuffle(
                    0, 1, "x")
                weightedstates = T.concatenate(
                    [weightedstates, weightedstates2], axis=2)
            out = T.sum(weightedstates, axis=1)     # (batsize, lastdim)
            outs.append(out)
        if self.mode == "concat":
            ret = T.concatenate(outs, axis=1)
        elif self.mode == "seq":
            # add a time axis per head, then stack along it
            outs = [out.dimshuffle(0, "x", 1) for out in outs]
            ret = T.concatenate(outs, axis=1)
        return ret