def test_output_shape(self):
     batsize = 100
     seqlen = 5
     dim = 50
     indim = 13
     m = SeqEncoder(IdxToOneHot(indim), GRU(dim=indim, innerdim=dim))
     data = np.random.randint(0, indim, (batsize, seqlen)).astype("int32")
     mpred = m.predict(data)
     self.assertEqual(mpred.shape, (batsize, dim))
 def test_output_shape_w_mask(self):
     batsize = 2
     seqlen = 5
     dim = 3
     indim = 7
     m = SeqEncoder(IdxToOneHot(indim), GRU(dim=indim, innerdim=dim)).all_outputs
     data = np.random.randint(0, indim, (batsize, seqlen)).astype("int32")
     mask = np.zeros_like(data).astype("float32")
     mask[:, 0:2] = 1
     weights = np.ones_like(data).astype("float32")
     mpred = m.predict(data, weights, mask)
     self.assertEqual(mpred.shape, (batsize, seqlen, dim))
 def test_mask_no_state_updates(self):
     batsize = 10
     seqlen = 3
     dim = 7
     indim = 5
     m = SeqEncoder(IdxToOneHot(indim), GRU(dim=indim, innerdim=dim)).maskoption(-1).all_outputs
     data = np.random.randint(0, indim, (batsize, seqlen)).astype("int32")
     data[:, 1] = 0
     ndata = np.ones_like(data) * -1
     data = np.concatenate([data, ndata], axis=1)
     pred = m.predict(data)
     for i in range(1, pred.shape[1]):
         print(np.linalg.norm(pred[:, i - 1, :] - pred[:, i, :]))
         if i < seqlen:
             # real input steps: consecutive states must differ
             self.assertTrue(not np.allclose(pred[:, i - 1, :], pred[:, i, :]))
         else:
             # maskid (-1) steps: the state is frozen, so outputs repeat
             self.assertTrue(np.allclose(pred[:, i - 1, :], pred[:, i, :]))
Example #4
 def setUp(self):
     dim = 50
     self.outdim = 100
     batsize = 1000
     seqlen = 19
     self.enc = SeqEncoder(None, GRU(dim=dim, innerdim=self.outdim))
     self.enc = self.doswitches(self.enc)
     self.data = np.random.random((batsize, seqlen, dim)).astype("float32")
     self.out = self.enc.predict(self.data)
 def test_mask_zero_mask_with_custom_maskid(self):
     batsize = 10
     seqlen = 3
     dim = 7
     indim = 5
     m = SeqEncoder(IdxToOneHot(indim), GRU(dim=indim, innerdim=dim)).maskoptions(-1, MaskSetMode.ZERO).all_outputs
     data = np.random.randint(0, indim, (batsize, seqlen)).astype("int32")
     data[:, 1] = 0
     ndata = np.ones_like(data) * -1
     data = np.concatenate([data, ndata], axis=1)
     pred = m.predict(data)
     for i in range(pred.shape[1]):
         if i > 0:
             print(np.linalg.norm(pred[:, i - 1, :] - pred[:, i, :]))
         if i < seqlen:
             # real input steps: outputs are nonzero
             for j in range(pred.shape[0]):
                 self.assertTrue(np.linalg.norm(pred[j, i, :]) > 0.0)
         else:
             # maskid (-1) steps with MaskSetMode.ZERO: outputs are zeroed out
             for j in range(pred.shape[0]):
                 self.assertTrue(np.linalg.norm(pred[j, i, :]) == 0.0)
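# Side-by-side sketch (inferred from the two tests above, not from library docs):
# maskoption(-1) freezes the GRU state at masked steps, so consecutive outputs
# become identical, while maskoptions(-1, MaskSetMode.ZERO) zeroes the outputs
# at masked steps instead:
#   frozen = SeqEncoder(emb, gru).maskoption(-1).all_outputs
#   zeroed = SeqEncoder(emb, gru).maskoptions(-1, MaskSetMode.ZERO).all_outputs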
Example #6
class SimpleRNNEncoderTest(TestCase):
    expectedparams = ["um", "wm", "uhf", "whf", "u", "w", "bm", "bhf", "b"]
    expectednumberparams = 1

    def setUp(self):
        dim = 50
        self.outdim = 100
        batsize = 1000
        seqlen = 19
        self.enc = SeqEncoder(None, GRU(dim=dim, innerdim=self.outdim))
        self.enc = self.doswitches(self.enc)
        self.data = np.random.random((batsize, seqlen, dim)).astype("float32")
        self.out = self.enc.predict(self.data)

    def test_output_shape(self):
        self.assertEqual(self.out.shape, self.expectshape(self.data.shape, self.outdim))

    def expectshape(self, datashape, outdim):
        return (datashape[0], outdim)

    def test_all_output_parameters(self):
        outputs = self.enc.wrapply(*self.enc.inputs)
        if issequence(outputs) and len(outputs) > 1:
            outputparamsets = [x.allparams for x in outputs]
            for i in range(len(outputparamsets)):
                for j in range(i, len(outputparamsets)):
                    self.assertSetEqual(outputparamsets[i], outputparamsets[j])
        if issequence(outputs):
            outputs = outputs[0]
        outputparamcounts = {}
        for paramname in [x.name for x in outputs.allparams]:
            if paramname not in outputparamcounts:
                outputparamcounts[paramname] = 0
            outputparamcounts[paramname] += 1
        for (_, y) in outputparamcounts.items():
            self.assertEqual(y, self.expectednumberparams)
        self.assertSetEqual(set(outputparamcounts.keys()), set(self.expectedparams))

    def doswitches(self, enc):
        return enc
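# Hypothetical extension sketch (class name assumed, not from the source):
# doswitches() is the hook subclasses override to reconfigure the encoder, and
# expectshape() adapts the shape check, e.g. for an all-outputs variant:
class AllOutputsRNNEncoderTest(SimpleRNNEncoderTest):
    def doswitches(self, enc):
        return enc.all_outputs()

    def expectshape(self, datashape, outdim):
        return (datashape[0], datashape[1], outdim)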
Example #7
 def setUp(self):
     batsize = 1000
     seqlen = 19
     indim = 71
     hdim = 51
     hdim2 = 61
     self.outdim = 47
     self.enc = SeqEncoder(None,
                           GRU(dim=indim, innerdim=hdim),
                           GRU(dim=hdim, innerdim=hdim2),
                           GRU(dim=hdim2, innerdim=self.outdim))
     self.enc = self.doswitches(self.enc)
     self.data = np.random.random((batsize, seqlen, indim)).astype("float32")
     self.out = self.enc.predict(self.data)
 def test_memory_block_with_seq_encoder(self):
     invocabsize = 5
     memsize = 10
     seqlen = 3
     encdim = 13
     data = np.random.randint(0, invocabsize, (memsize, seqlen))
     gru = GRU(dim=invocabsize, innerdim=encdim)
     payload = SeqEncoder(IdxToOneHot(vocsize=invocabsize), gru)
     memb = MemoryBlock(payload, data, indim=invocabsize, outdim=encdim)
     idxs = [0, 2, 5]
     memory_element = memb.predict(idxs)
     self.assertEqual(memory_element.shape, (len(idxs), encdim))
     gruparams = set([getattr(gru, pname) for pname in gru.paramnames])
     allparams = set(memb.output.allparams)
     self.assertEqual(gruparams.intersection(allparams), allparams)
Example #9
 def __init__(self,
              indim=400,
              inpembdim=50,
              inpemb=None,
              mode="concat",
              innerdim=100,
              numouts=1,
              maskid=0,
              bidir=False,
              maskmode=MaskMode.NONE,
              **kw):
     super(SimpleSeq2MultiVec, self).__init__(**kw)
     if inpemb is None:
         if inpembdim is None:
             inpemb = IdxToOneHot(indim)
             inpembdim = indim
         else:
             inpemb = VectorEmbed(indim=indim, dim=inpembdim)
     elif inpemb is False:
         inpemb = None
     else:
         inpembdim = inpemb.outdim
     if not issequence(innerdim):
         innerdim = [innerdim]
     innerdim[-1] += numouts
     rnn, lastdim = self.makernu(inpembdim, innerdim, bidir=bidir)
     self.outdim = lastdim * numouts if mode == "concat" else lastdim
     self.maskid = maskid
     self.inpemb = inpemb
     self.numouts = numouts
     self.mode = mode
     self.bidir = bidir
     if not issequence(rnn):
         rnn = [rnn]
     self.enc = SeqEncoder(inpemb, *rnn).maskoptions(maskid, maskmode)
     self.enc.all_outputs()
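# Worked example (read off the constructor above): with innerdim=100 and
# numouts=2, the last layer is widened to 102 units; mode="concat" then gives
# self.outdim = lastdim * numouts = 204, while any other mode leaves it at lastdim.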
Example #10
class SimpleRNNEncoderTest(TestCase):
    expectedparams = ["um", "wm", "uhf", "whf", "u", "w", "bm", "bhf", "b"]
    expectednumberparams = 1

    def setUp(self):
        dim = 50
        self.outdim = 100
        batsize = 1000
        seqlen = 19
        self.enc = SeqEncoder(None, GRU(dim=dim, innerdim=self.outdim))
        self.enc = self.doswitches(self.enc)
        self.data = np.random.random((batsize, seqlen, dim)).astype("float32")
        self.p = self.enc.predict
        self.out = self.p(self.data)

    def test_output_shape(self):
        out = self.out
        if isinstance(self.out, tuple):
            out = self.out[0]
        self.assertEqual(out.shape, self.expectshape(self.data.shape, self.outdim))

    def expectshape(self, datashape, outdim):
        return (datashape[0], outdim)

    def test_all_output_parameters(self):
        outputs = self.enc.wrapply(*self.p.inps)
        if issequence(outputs) and len(outputs) > 1:
            outputparamsets = [x.allparams for x in outputs if isinstance(x, (Var, Val))]
            for i in range(len(outputparamsets)):
                for j in range(i, len(outputparamsets)):
                    self.assertSetEqual(outputparamsets[i], outputparamsets[j])
        if issequence(outputs):
            outputs = outputs[0]
        outputparamcounts = {}
        for paramname in [x.name for x in outputs.allparams]:
            if paramname not in outputparamcounts:
                outputparamcounts[paramname] = 0
            outputparamcounts[paramname] += 1
        for (_, y) in outputparamcounts.items():
            self.assertEqual(y, self.expectednumberparams)
        self.assertSetEqual(set(outputparamcounts.keys()), set(self.expectedparams))

    def doswitches(self, enc):
        return enc
Example #11
    def __init__(self,  entembdim=50,
                        wordembdim=50,
                        wordencdim=100,
                        memdata=None,
                        attdim=100,
                        numchars=128,       # number of different chars
                        numwords=4e5,       # number of different words
                        glovepath=None,
                        innerdim=100,       # dim of memory payload encoder output
                        outdim=1e4,         # number of entities
                        memaddr=DotMemAddr, **kw):
        super(FBMemMatch, self).__init__(**kw)
        self.wordembdim = wordembdim
        self.wordencdim = wordencdim
        self.entembdim = entembdim
        self.attdim = attdim
        self.encinnerdim = innerdim
        self.outdim = outdim

        memaddr = TransDotMemAddr  # NOTE: hard-coded override of the memaddr argument

        # memory encoder per word
        #wencpg = WordEmbed(indim=numwords, outdim=self.wordembdim, trainfrac=1.0)
        wordencoder = WordEncoderPlusGlove(numchars=numchars, numwords=numwords, encdim=self.wordencdim,
                                      embdim=self.wordembdim, embtrainfrac=0.0, glovepath=glovepath)

        # memory encoder for one cell
        self.phraseencoder = SeqEncoder(
            wordencoder,
            GRU(dim=self.wordembdim + self.wordencdim,
                innerdim=self.encinnerdim)
        )
        # entity embedder
        entemb = VectorEmbed(indim=self.outdim, dim=self.entembdim)
        self.entembs = entemb(memdata[0])  # alternative: entemb(Val(np.arange(0, self.outdim, dtype="int32")))
        # memory block
        self.mempayload = self.phraseencoder #ConcatBlock(entemb, self.phraseencoder)
        self.memblock = MemoryBlock(self.mempayload, memdata[1], indim=self.outdim,
                                    outdim=self.encinnerdim)  # + self.entembdim
        # memory addressing
        self.mema = memaddr(self.memblock,
                       memdim=self.memblock.outdim, attdim=attdim, indim=self.encinnerdim)
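# Hedged construction sketch (memdata layout inferred from the code above:
# memdata[0] holds the entity ids fed to the entity embedder, memdata[1] the
# index matrix encoded by the memory payload encoder; values are illustrative
# and glovepath handling is not shown in this snippet):
# entids = Val(np.arange(0, 10000, dtype="int32"))
# wordmat = np.random.randint(0, 400000, (10000, 3)).astype("int32")
# m = FBMemMatch(memdata=(entids, wordmat), outdim=10000, glovepath=None)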
Example #12
    def __init__(self, wordembdim=50, wordencdim=100, entembdim=200, innerdim=200, outdim=1e4, numwords=4e5, numchars=128, glovepath=None, **kw):
        super(FBSeqCompositeEncDec, self).__init__(**kw)
        self.indim = wordembdim + wordencdim
        self.outdim = outdim
        self.wordembdim = wordembdim
        self.wordencdim = wordencdim
        self.encinnerdim = innerdim
        self.entembdim = entembdim
        self.decinnerdim = innerdim

        self.enc = SeqEncoder(
            WordEncoderPlusGlove(numchars=numchars, numwords=numwords, encdim=self.wordencdim, embdim=self.wordembdim, embtrainfrac=0.0, glovepath=glovepath),
            GRU(dim=self.wordembdim + self.wordencdim, innerdim=self.encinnerdim)
        )

        self.dec = SeqDecoder(
            [VectorEmbed(indim=self.outdim, dim=self.entembdim), GRU(dim=self.entembdim+self.encinnerdim, innerdim=self.decinnerdim)],
            inconcat=True,
            innerdim=self.decinnerdim,
        )
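# Hedged usage sketch (values illustrative; the training/prediction calls are
# not part of this snippet):
# m = FBSeqCompositeEncDec(wordembdim=50, wordencdim=100, entembdim=200,
#                          innerdim=200, outdim=10000, numwords=400000,
#                          numchars=128, glovepath=None)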
Example #13
class StackRNNEncoderTest(SimpleRNNEncoderTest):
    expectednumberparams = 3

    def setUp(self):
        batsize = 1000
        seqlen = 19
        indim = 71
        hdim = 51
        hdim2 = 61
        self.outdim = 47
        self.enc = SeqEncoder(None,
                              GRU(dim=indim, innerdim=hdim),
                              GRU(dim=hdim, innerdim=hdim2),
                              GRU(dim=hdim2, innerdim=self.outdim))
        self.enc = self.doswitches(self.enc)
        self.data = np.random.random((batsize, seqlen, indim)).astype("float32")
        self.out = self.enc.predict(self.data)

    def doswitches(self, enc):
        return enc
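# Note: expectednumberparams = 3 matches the stack above; each of the three
# GRU layers contributes one parameter of every name in expectedparams.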
Example #14
    def test_memory_block_with_seq_encoder_dynamic(self):
        invocabsize = 5
        memsize = 10
        seqlen = 3
        encdim = 13
        data = np.random.randint(0, invocabsize, (memsize, seqlen))
        gru = GRU(dim=invocabsize, innerdim=encdim)
        payload = SeqEncoder(IdxToOneHot(vocsize=invocabsize), gru)
        dynmemb = MemoryBlock(payload, outdim=encdim)
        idxs = [0, 2, 5]
        p = dynmemb.predict
        memory_element = p(idxs, data)
        self.assertEqual(memory_element.shape, (len(idxs), encdim))
        gruparams = set([
            getattr(gru, pname)
            for pname in "u w b uhf whf bhf um wm bm".split()
        ])
        allparams = set(p.outs[0].allparams)
        self.assertEqual(gruparams.intersection(allparams), allparams)

        statmemb = MemoryBlock(payload, data, outdim=encdim)
        statpred = statmemb.predict(idxs)
        self.assertTrue(np.allclose(statpred, memory_element))
Example #15
    def init(self):
        #MEMORY: encodes how entity is written + custom entity embeddings
        wencpg = WordEncoderPlusGlove(numchars=self.numchars, numwords=self.numwords, encdim=self.wordencdim, embdim=self.wordembdim, embtrainfrac=0.0, glovepath=self.glovepath)
        self.memenco = SeqEncoder(
            wencpg,
            GRU(dim=self.wordembdim + self.wordencdim, innerdim=self.encinnerdim)
        )

        entemb = VectorEmbed(indim=self.outdim, dim=self.entembdim)
        self.mempayload = ConcatBlock(entemb, self.memenco)
        self.memblock = MemoryBlock(self.mempayload, self.memdata, indim=self.outdim, outdim=self.encinnerdim+self.entembdim)

        #ENCODER: uses the same language encoder as memory
        #wencpg2 = WordEncoderPlusGlove(numchars=self.numchars, numwords=self.numwords, encdim=self.wordencdim, embdim=self.wordembdim, embtrainfrac=0.0, glovepath=glovepath)
        self.enc = RecStack(wencpg, GRU(dim=self.wordembdim + self.wordencdim, innerdim=self.encinnerdim))

        #ATTENTION
        attgen = LinearGateAttentionGenerator(indim=self.encinnerdim + self.decinnerdim, innerdim=self.attdim)
        attcon = WeightedSumAttCon()

        #DECODER
        #entemb2 = VectorEmbed(indim=self.outdim, dim=self.entembdim)
        self.softmaxoutblock = stack(
            self.memaddr(
                self.memblock,
                indim=self.decinnerdim + self.encinnerdim,
                memdim=self.memblock.outdim,
                attdim=self.attdim),
            Softmax())

        self.dec = SeqDecoder(
            [self.memblock, GRU(dim=self.entembdim + self.encinnerdim, innerdim=self.decinnerdim)],
            outconcat=True, inconcat=False,
            attention=Attention(attgen, attcon),
            innerdim=self.decinnerdim + self.encinnerdim,
            softmaxoutblock=self.softmaxoutblock
        )
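# Data-flow recap (read off the code above): one shared WordEncoderPlusGlove
# feeds both the memory encoder and the input encoder; the decoder attends over
# the input encoding via LinearGateAttentionGenerator + WeightedSumAttCon, and
# the memory-addressing softmax over the memory block serves as its output layer.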
Example #16
class Seq2Vec(Block):
    def __init__(self, inpemb, enclayers, maskid=0, pool=None, **kw):
        super(Seq2Vec, self).__init__(**kw)
        self.maskid = maskid
        self.inpemb = inpemb
        if not issequence(enclayers):
            enclayers = [enclayers]
        self.pool = pool
        self.enc = SeqEncoder(inpemb,
                              *enclayers).maskoptions(maskid, MaskMode.AUTO)
        if self.pool is not None:
            self.enc = self.enc.all_outputs.with_mask

    def all_outputs(self):
        self.enc = self.enc.all_outputs()
        return self

    def apply(self, x, mask=None, weights=None):
        if self.pool is not None:
            ret, mask = self.enc(x, mask=mask, weights=weights)
            ret = self.pool(ret, mask)
        else:
            ret = self.enc(x, mask=mask, weights=weights)
        return ret
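# Hedged usage sketch (dims illustrative, mirroring the tests earlier in this
# listing): with pool=None, Seq2Vec reduces a batch of int sequences to one
# vector per example.
emb = VectorEmbed(indim=100, dim=20)
enc = Seq2Vec(emb, GRU(dim=20, innerdim=30), maskid=0)
x = np.random.randint(1, 100, (8, 5)).astype("int32")
# expected: enc.predict(x).shape == (8, 30)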
Example #17
 def __init__(self, embedder, *layers, **kw):
     super(SeqTrans, self).__init__(**kw)
     self.enc = SeqEncoder(embedder, *layers)
     self.enc.all_outputs().maskoption(MaskMode.NONE)
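# Hedged usage sketch (dims illustrative): because the constructor switches the
# encoder to all_outputs() with MaskMode.NONE, SeqTrans should emit one vector
# per timestep.
m = SeqTrans(IdxToOneHot(7), GRU(dim=7, innerdim=11))
x = np.random.randint(0, 7, (4, 6)).astype("int32")
# expected: m.predict(x).shape == (4, 6, 11)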
Example #18
    def __init__(self,
                 inpvocsize=None,
                 inpembdim=None,
                 inpemb=None,
                 inpencinnerdim=None,
                 bidir=False,
                 maskid=None,
                 dropout=False,
                 rnu=GRU,
                 inpencoder=None,
                 memvocsize=None,
                 memembdim=None,
                 memembmat=None,
                 memencinnerdim=None,
                 memencoder=None,
                 inp_att_dist=CosineDistance(),
                 mem_att_dist=CosineDistance(),
                 inp_attention=None,
                 mem_attention=None,
                 coredims=None,
                 corernu=GRU,
                 core=None,
                 explicit_interface=False,
                 scalaraggdim=None,
                 write_value_dim=None,
                 nsteps=100,
                 posvecdim=None,
                 mem_pos_repr=None,
                 inp_pos_repr=None,
                 inp_addr_extractor=None,
                 mem_addr_extractor=None,
                 write_addr_extractor=None,
                 write_addr_generator=None,
                 write_addr_dist=CosineDistance(),
                 write_value_generator=None,
                 write_value_extractor=None,
                 mem_erase_generator=None,
                 mem_change_generator=None,
                 memsampler=None,
                 memsamplemethod=None,
                 memsampletemp=0.3,
                 **kw):

        # INPUT ENCODING
        if inpencoder is None:
            inpencoder = SeqEncoder.RNN(indim=inpvocsize,
                                        inpembdim=inpembdim,
                                        inpemb=inpemb,
                                        innerdim=inpencinnerdim,
                                        bidir=bidir,
                                        maskid=maskid,
                                        dropout_in=dropout,
                                        dropout_h=dropout,
                                        rnu=rnu).all_outputs()
            lastinpdim = inpencinnerdim if not issequence(
                inpencinnerdim) else inpencinnerdim[-1]
        else:
            lastinpdim = inpencoder.block.layers[-1].innerdim

        # MEMORY ENCODING
        if memembmat is None:
            memembmat = param((memvocsize, memembdim),
                              name="memembmat").glorotuniform()
        if memencoder is None:
            memencoder = SeqEncoder.RNN(inpemb=False,
                                        innerdim=memencinnerdim,
                                        bidir=bidir,
                                        dropout_in=dropout,
                                        dropout_h=dropout,
                                        rnu=rnu,
                                        inpembdim=memembdim).all_outputs()
            lastmemdim = memencinnerdim if not issequence(
                memencinnerdim) else memencinnerdim[-1]
        else:
            lastmemdim = memencoder.block.layers[-1].innerdim

        # POSITION VECTORS
        if posvecdim is not None and inp_pos_repr is None:
            inp_pos_repr = RNNWithoutInput(posvecdim, dropout=dropout)
        if posvecdim is not None and mem_pos_repr is None:
            mem_pos_repr = RNNWithoutInput(posvecdim, dropout=dropout)

        xtra_dim = posvecdim if posvecdim is not None else 0
        # CORE RNN - THE THINKER
        if core is None:
            corelayers, _ = MakeRNU.fromdims(
                [lastinpdim + lastmemdim + xtra_dim * 2] + coredims,
                rnu=corernu,
                dropout_in=dropout,
                dropout_h=dropout,
                param_init_states=True)
            core = RecStack(*corelayers)

        lastcoredim = core.get_statespec()[-1][0][1][0]

        # ATTENTIONS
        if mem_attention is None:
            mem_attention = Attention(mem_att_dist)
        if inp_attention is None:
            inp_attention = Attention(inp_att_dist)
        if write_addr_generator is None:
            write_addr_generator = AttGen(write_addr_dist)

        # WRITE VALUE
        if write_value_generator is None:
            write_value_generator = WriteValGenerator(write_value_dim,
                                                      memvocsize,
                                                      dropout=dropout)

        # MEMORY SAMPLER
        if memsampler is not None:
            assert (memsamplemethod is None)
        if memsamplemethod is not None:
            assert (memsampler is None)
            memsampler = GumbelSoftmax(temperature=memsampletemp)

        ################ STATE INTERFACES #################

        if not explicit_interface:
            if inp_addr_extractor is None:
                inp_addr_extractor = Forward(lastcoredim,
                                             lastinpdim + xtra_dim,
                                             dropout=dropout)
            if mem_addr_extractor is None:
                mem_addr_extractor = Forward(lastcoredim,
                                             lastmemdim + xtra_dim,
                                             dropout=dropout)

            # WRITE INTERFACE
            if write_addr_extractor is None:
                write_addr_extractor = Forward(lastcoredim,
                                               lastmemdim + xtra_dim,
                                               dropout=dropout)
            if write_value_extractor is None:
                write_value_extractor = Forward(lastcoredim,
                                                write_value_dim,
                                                dropout=dropout)

            # MEM UPDATE INTERFACE
            if mem_erase_generator is None:
                mem_erase_generator = StateToScalar(lastcoredim, scalaraggdim)
            if mem_change_generator is None:
                mem_change_generator = StateToScalar(lastcoredim, scalaraggdim)
        else:
            inp_addr_extractor, mem_addr_extractor, write_addr_extractor, \
            write_value_extractor, mem_erase_generator, mem_change_generator = \
                make_vector_slicers(0, lastinpdim + xtra_dim, lastmemdim + xtra_dim,
                                    lastmemdim + xtra_dim, write_value_dim, 1, 1)

        super(SimpleBulkNN,
              self).__init__(inpencoder=inpencoder,
                             memembmat=memembmat,
                             memencoder=memencoder,
                             inp_attention=inp_attention,
                             mem_attention=mem_attention,
                             core=core,
                             memsampler=memsampler,
                             nsteps=nsteps,
                             inp_addr_extractor=inp_addr_extractor,
                             mem_addr_extractor=mem_addr_extractor,
                             write_addr_extractor=write_addr_extractor,
                             write_addr_generator=write_addr_generator,
                             mem_erase_generator=mem_erase_generator,
                             mem_change_generator=mem_change_generator,
                             write_value_generator=write_value_generator,
                             write_value_extractor=write_value_extractor,
                             inp_pos_repr=inp_pos_repr,
                             mem_pos_repr=mem_pos_repr,
                             **kw)
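# Hedged construction sketch (kwargs read off the default-handling paths above;
# values illustrative): when no prebuilt encoders or core are passed in, at
# least these dimensions must be supplied.
# m = SimpleBulkNN(inpvocsize=100, inpembdim=20, inpencinnerdim=30,
#                  memvocsize=50, memembdim=20, memencinnerdim=30,
#                  coredims=[40], write_value_dim=20, scalaraggdim=10,
#                  nsteps=5)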
Example #19
class SimpleSeq2MultiVec(Block):
    def __init__(self,
                 indim=400,
                 inpembdim=50,
                 inpemb=None,
                 mode="concat",
                 innerdim=100,
                 numouts=1,
                 maskid=0,
                 bidir=False,
                 maskmode=MaskMode.NONE,
                 **kw):
        super(SimpleSeq2MultiVec, self).__init__(**kw)
        if inpemb is None:
            if inpembdim is None:
                inpemb = IdxToOneHot(indim)
                inpembdim = indim
            else:
                inpemb = VectorEmbed(indim=indim, dim=inpembdim)
        elif inpemb is False:
            inpemb = None
        else:
            inpembdim = inpemb.outdim
        if not issequence(innerdim):
            innerdim = [innerdim]
        innerdim[-1] += numouts
        rnn, lastdim = self.makernu(inpembdim, innerdim, bidir=bidir)
        self.outdim = lastdim * numouts if mode == "concat" else lastdim
        self.maskid = maskid
        self.inpemb = inpemb
        self.numouts = numouts
        self.mode = mode
        self.bidir = bidir
        if not issequence(rnn):
            rnn = [rnn]
        self.enc = SeqEncoder(inpemb, *rnn).maskoptions(maskid, maskmode)
        self.enc.all_outputs()

    @staticmethod
    def makernu(inpembdim, innerdim, bidir=False):
        return MakeRNU.make(inpembdim, innerdim, bidir=bidir)

    def apply(self, x, mask=None, weights=None):
        ret = self.enc(x, mask=mask,
                       weights=weights)  # (batsize, seqlen, lastdim)
        outs = []
        # apply mask (SeqEncoder attaches a mask to the output var when all_outputs() is set)
        mask = mask if mask is not None else ret.mask if hasattr(
            ret, "mask") else None
        if self.bidir:
            mid = ret.shape[2] // 2  # split point between forward and backward halves
            ret1 = ret[:, :, :mid]
            ret2 = ret[:, :, mid:]
            ret = ret1
        for i in range(self.numouts):
            selfweights = ret[:, :, i]  # (batsize, seqlen)
            if self.bidir:
                selfweights += ret2[:, :, i]
            selfweights = Softmax()(selfweights)
            if mask is not None:
                selfweights *= mask  # apply mask
            selfweights = selfweights / T.sum(selfweights, axis=1).dimshuffle(
                0, "x")  # renormalize
            weightedstates = ret[:, :, self.numouts:] * selfweights.dimshuffle(
                0, 1, "x")
            if self.bidir:
                weightedstates2 = ret2[:, :,
                                       self.numouts:] * selfweights.dimshuffle(
                                           0, 1, "x")
                weightedstates = T.concatenate(
                    [weightedstates, weightedstates2], axis=2)
            out = T.sum(weightedstates, axis=1)  # (batsize, lastdim)
            outs.append(out)
        if self.mode == "concat":
            ret = T.concatenate(outs, axis=1)
        elif self.mode == "seq":
            outs = [out.dimshuffle(0, "x", 1) for out in outs]
            ret = T.concatenate(outs, axis=1)
        return ret
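# Hedged usage sketch (dims illustrative): with numouts=2 the block computes two
# attention-weighted summaries of the sequence; mode="concat" glues them side by
# side, mode="seq" stacks them along a new axis.
m = SimpleSeq2MultiVec(indim=50, inpembdim=10, innerdim=20, numouts=2,
                       mode="concat")
x = np.random.randint(1, 50, (4, 6)).astype("int32")
# expected: m.predict(x).shape == (4, m.outdim)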