Example No. 1
class SeqTrans(Block):
    """Sequence transducer: encodes the input and returns an output for every timestep."""

    def __init__(self, embedder, *layers, **kw):
        super(SeqTrans, self).__init__(**kw)
        self.enc = SeqEncoder(embedder, *layers)
        # emit an output for every timestep and disable masking
        self.enc.all_outputs().maskoption(MaskMode.NONE)

    def apply(self, x):
        return self.enc(x)
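A minimal usage sketch for SeqTrans, not part of the original listing. It assumes the library's VectorEmbed block (shown in Example No. 3) and a recurrent layer block such as GRU(dim=..., innerdim=...); the imports, the GRU constructor and the shapes in the comments are assumptions for illustration only.

# hypothetical usage sketch; VectorEmbed appears in Example No. 3, GRU is assumed to exist
emb = VectorEmbed(indim=1000, dim=50)               # 1000-symbol vocabulary, 50-dim embeddings
trans = SeqTrans(emb, GRU(dim=50, innerdim=100))
y = trans(x)    # x: (batsize, seqlen) integer ids -> per-timestep outputs (batsize, seqlen, 100)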
Example No. 2
class Seq2Vec(Block):
    """Encodes a sequence into a single vector: the final encoder state, or a pooled summary of all states."""

    def __init__(self, inpemb, enclayers, maskid=0, pool=None, **kw):
        super(Seq2Vec, self).__init__(**kw)
        self.maskid = maskid
        self.inpemb = inpemb
        if not issequence(enclayers):
            enclayers = [enclayers]
        self.pool = pool
        self.enc = SeqEncoder(inpemb, *enclayers).maskoptions(maskid, MaskMode.AUTO)
        if self.pool is not None:
            # pooling needs every timestep output plus the mask
            self.enc = self.enc.all_outputs.with_mask

    def all_outputs(self):
        self.enc = self.enc.all_outputs()
        return self

    def apply(self, x, mask=None, weights=None):
        if self.pool is not None:
            ret, mask = self.enc(x, mask=mask, weights=weights)
            ret = self.pool(ret, mask)
        else:
            ret = self.enc(x, mask=mask, weights=weights)
        return ret
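A hedged usage sketch for Seq2Vec, not part of the original listing. VectorEmbed and GRU are assumed as in the sketch after Example No. 1; meanpool is a made-up pooling callable following the (outputs, mask) signature that apply() expects.

# hypothetical usage sketch; meanpool is a made-up pooling callable, not part of the listing
def meanpool(outputs, mask):
    # outputs: (batsize, seqlen, dim), mask: (batsize, seqlen)
    mask = mask.dimshuffle(0, 1, "x")
    return T.sum(outputs * mask, axis=1) / T.sum(mask, axis=1)

emb = VectorEmbed(indim=1000, dim=50)
enc = Seq2Vec(emb, GRU(dim=50, innerdim=100), maskid=0)                    # final-state encoding
pooled = Seq2Vec(emb, GRU(dim=50, innerdim=100), maskid=0, pool=meanpool)  # mean over unmasked timesteps
vec = enc(x)        # x: (batsize, seqlen) integer ids -> (batsize, 100)
vec2 = pooled(x)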
Example No. 3
class SimpleSeq2MultiVec(Block):
    """Encodes a sequence into numouts vectors using learned self-attention weights over the timesteps."""

    def __init__(self,
                 indim=400,
                 inpembdim=50,
                 inpemb=None,
                 mode="concat",
                 innerdim=100,
                 numouts=1,
                 maskid=0,
                 bidir=False,
                 maskmode=MaskMode.NONE,
                 **kw):
        super(SimpleSeq2MultiVec, self).__init__(**kw)
        if inpemb is None:
            if inpembdim is None:
                inpemb = IdxToOneHot(indim)
                inpembdim = indim
            else:
                inpemb = VectorEmbed(indim=indim, dim=inpembdim)
        elif inpemb is False:
            inpemb = None
        else:
            inpembdim = inpemb.outdim
        if not issequence(innerdim):
            innerdim = [innerdim]
        # reserve one extra output dimension per output vector; apply() uses these
        # extra channels as self-attention scores over the timesteps
        innerdim[-1] += numouts
        rnn, lastdim = self.makernu(inpembdim, innerdim, bidir=bidir)
        self.outdim = lastdim * numouts if mode == "concat" else lastdim
        self.maskid = maskid
        self.inpemb = inpemb
        self.numouts = numouts
        self.mode = mode
        self.bidir = bidir
        if not issequence(rnn):
            rnn = [rnn]
        self.enc = SeqEncoder(inpemb, *rnn).maskoptions(maskid, maskmode)
        self.enc.all_outputs()

    @staticmethod
    def makernu(inpembdim, innerdim, bidir=False):
        return MakeRNU.make(inpembdim, innerdim, bidir=bidir)

    def apply(self, x, mask=None, weights=None):
        ret = self.enc(x, mask=mask, weights=weights)  # (batsize, seqlen, lastdim)
        outs = []
        # recover the mask (SeqEncoder should attach it to the output var when all_outputs() is set)
        mask = mask if mask is not None else (ret.mask if hasattr(ret, "mask") else None)
        if self.bidir:
            # split the concatenated forward/backward halves of the bidirectional output
            mid = ret.shape[2] // 2
            ret1 = ret[:, :, :mid]
            ret2 = ret[:, :, mid:]
            ret = ret1
        for i in range(self.numouts):
            selfweights = ret[:, :, i]  # (batsize, seqlen)
            if self.bidir:
                selfweights += ret2[:, :, i]
            selfweights = Softmax()(selfweights)
            if mask is not None:
                selfweights *= mask  # apply mask: zero out weights at padded positions
            # renormalize the attention weights after masking
            selfweights = selfweights / T.sum(selfweights, axis=1).dimshuffle(0, "x")
            # attention-weighted sum over timesteps of the remaining state dimensions
            weightedstates = ret[:, :, self.numouts:] * selfweights.dimshuffle(0, 1, "x")
            if self.bidir:
                weightedstates2 = ret2[:, :, self.numouts:] * selfweights.dimshuffle(0, 1, "x")
                weightedstates = T.concatenate([weightedstates, weightedstates2], axis=2)
            out = T.sum(weightedstates, axis=1)  # (batsize, lastdim)
            outs.append(out)
        if self.mode == "concat":
            # concatenate the numouts vectors along the feature axis
            ret = T.concatenate(outs, axis=1)
        elif self.mode == "seq":
            # stack the numouts vectors along a new sequence axis
            outs = [out.dimshuffle(0, "x", 1) for out in outs]
            ret = T.concatenate(outs, axis=1)
        return ret
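A hedged usage sketch for SimpleSeq2MultiVec, not part of the original listing; it uses only constructor parameters shown above, and x stands for an illustrative (batsize, seqlen) integer id matrix.

# hypothetical usage sketch built only from the constructor parameters shown above
enc = SimpleSeq2MultiVec(indim=400, inpembdim=50, innerdim=100,
                         numouts=2, mode="concat", bidir=False)
vecs = enc(x)       # mode="concat": the two attention-pooled vectors concatenated, width enc.outdim

seqenc = SimpleSeq2MultiVec(indim=400, inpembdim=50, innerdim=100,
                            numouts=2, mode="seq")
vecseq = seqenc(x)  # mode="seq": the two vectors stacked along a new axis, (batsize, 2, lastdim)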