def test_shape(self):
     """LinearDistance.predict should return one value per batch row.

     NOTE(review): this uses the two-argument form
     LinearDistance(ldim + rdim, aggdim), while other tests in this file pass
     (ldim, rdim, aggdim); presumably it targets an older signature - confirm.
     NOTE(review): test_shape is defined twice in this file; if both copies
     end up in the same class, only the later definition is collected.
     """
     batsize, ldim, rdim, aggdim = 10, 5, 4, 7
     left = np.random.random((batsize, ldim))
     right = np.random.random((batsize, rdim))
     distance = LinearDistance(ldim + rdim, aggdim)
     scores = distance.predict(left, right)
     self.assertEqual(scores.shape, (batsize,))
 def test_shape(self):
     """predict() on (batsize, ldim) and (batsize, rdim) inputs yields shape (batsize,)."""
     n_rows = 10
     left_width, right_width = 5, 4
     out_width = 7
     lhs = np.random.random((n_rows, left_width))
     rhs = np.random.random((n_rows, right_width))
     dist = LinearDistance(left_width + right_width, out_width)
     result = dist.predict(lhs, rhs)
     expected_shape = (n_rows,)
     self.assertEqual(result.shape, expected_shape)
示例#3
0
    def do_test_shapes(self, bidir=False, sepatt=False, rnu=GRU):
        """Shared shape checks for SimpleSeqEncDecAtt.

        The name has no ``test`` prefix, so unittest does not collect it
        directly; presumably concrete test_* methods call it with different
        argument combinations.

        Builds an encoder/decoder with two encoder layers and two decoder
        layers plus attention. It runs the encoder and the full model on random
        integer sequences and asserts the shapes of:
        - the encoder output,
        - the per-layer encoder states,
        - the model prediction.

        Args:
            bidir: build a bidirectional encoder. The per-direction encoder
                sizes are halved, and the expected output widths are doubled
                back to account for both directions.
            sepatt: forwarded to SimpleSeqEncDecAtt. When true, the test also
                checks that ``enco`` has an extra axis of size 2.
            rnu: recurrent unit class. For GRU and LSTM the number and shape
                of the encoder states are checked; any other unit skips those
                checks.
        """
        inpvocsize = 100
        outvocsize = 13
        inpembdim = 10
        outembdim = 7
        encdim = [26, 14]
        decdim = [21, 15]
        batsize = 11
        inpseqlen = 6
        outseqlen = 5

        if bidir:
            # Floor division keeps the layer sizes integral on Python 3 too
            # (plain `/` would produce floats there).
            encdim = [e // 2 for e in encdim]
        # Expected widths are doubled when bidirectional (both directions).
        dirmul = 2 if bidir else 1

        m = SimpleSeqEncDecAtt(inpvocsize=inpvocsize,
                               inpembdim=inpembdim,
                               outvocsize=outvocsize,
                               outembdim=outembdim,
                               encdim=encdim,
                               decdim=decdim,
                               bidir=bidir,
                               statetrans=True,
                               # NOTE(review): 15 and 14 appear to match
                               # decdim[-1] and the final encoder output
                               # width; confirm against LinearDistance.
                               attdist=LinearDistance(15, 14, 17),
                               sepatt=sepatt,
                               rnu=rnu)

        inpseq = np.random.randint(0, inpvocsize,
                                   (batsize, inpseqlen)).astype("int32")
        outseq = np.random.randint(0, outvocsize,
                                   (batsize, outseqlen)).astype("int32")

        predenco, enco, states = m.enc.predict(inpseq)
        self.assertEqual(predenco.shape, (batsize, encdim[-1] * dirmul))

        # Expected last-axis size of each returned encoder state, in order.
        if rnu == GRU:
            # One state per encoder layer.
            expected_state_dims = list(encdim)
        elif rnu == LSTM:
            # Two states per encoder layer.
            expected_state_dims = [encdim[0], encdim[0], encdim[1], encdim[1]]
        else:
            expected_state_dims = None
        if expected_state_dims is not None:
            self.assertEqual(len(states), len(expected_state_dims))
            for state, encdime in zip(states, expected_state_dims):
                self.assertEqual(state.shape,
                                 (batsize, inpseqlen, encdime * dirmul))

        if sepatt:
            self.assertEqual(enco.shape,
                             (batsize, inpseqlen, 2, encdim[-1] * dirmul))

        pred = m.predict(inpseq, outseq)
        self.assertEqual(pred.shape, (batsize, outseqlen, outvocsize))

        # Print every parameter of the built output variable, sorted by str().
        _, outvar = m.autobuild(inpseq, outseq)
        for p in sorted(outvar[0].allparams, key=str):
            print(p)
示例#4
0
 def test_set_lr(self):
     """set_lr must set the per-parameter lrmul of decoder and attention.

     Expected multipliers:
     - encoder parameters (no set_lr call): 1.0;
     - decoder parameters outside the attention block: 0.1;
     - attention parameters: 0.5.
     The attention call is made after the decoder-wide one, presumably so that
     it overrides the 0.1 for the attention parameters.
     """
     attdist = LinearDistance(110, 110, 100)
     encdec = SimpleSeqEncDecAtt(inpvocsize=19,
                                 outvocsize=17,
                                 outconcat=False,
                                 encdim=110,
                                 decdim=110,
                                 attdist=attdist)
     encdec.dec.set_lr(0.1)
     encdec.dec.attention.set_lr(0.5)  # TODO
     encdata = np.random.randint(0, 19, (2, 5))
     decdata = np.random.randint(0, 17, (2, 5))
     # Apply the model once; presumably this builds the graph so that every
     # parameter exists before it is inspected. The result itself is unused.
     encdec(Val(encdata), Val(decdata))
     encparams = encdec.enc.get_params()
     decparams = encdec.dec.get_params()
     attparams = encdec.dec.attention.get_params()
     # Debug output: "<param>: <lrmul>" per line for encoder, then decoder.
     print("\n".join(["{}: {}".format(x, x.lrmul)
                      for x in encparams]) + "\n")
     print("\n".join(["{}: {}".format(x, x.lrmul)
                      for x in decparams]) + "\n")
     for x in encparams:
         self.assertEqual(x.lrmul, 1.0)
     for x in decparams:
         expected = 0.5 if x in attparams else 0.1
         self.assertEqual(x.lrmul, expected)
示例#5
0
    def test_get_params(self):
        """get_params() must return exactly the hand-enumerated parameter set.

        The expected set is built from:
        - the nine GRU attributes of every encoder and decoder layer;
        - the decoder's ``lin`` W and b;
        - the attention distance's ``lin``/``lin2`` W and b;
        - the attention distance's ``agg`` parameter.
        """
        attdist = LinearDistance(110, 110, 100)
        encdec = SimpleSeqEncDecAtt(inpvocsize=19,
                                    outvocsize=17,
                                    outconcat=False,
                                    encdim=(110, 100),
                                    decdim=100,
                                    attdist=attdist)
        # Attribute names of the per-layer GRU parameters.
        gru_param_names = ("w", "wm", "whf", "u", "um", "uhf", "b", "bm", "bhf")
        expected = set()
        for coder in (encdec.enc, encdec.dec):
            for layer in coder.block.layers:
                expected.update(getattr(layer, name) for name in gru_param_names)
        declin = encdec.dec.lin
        expected.update([declin.W, declin.b])
        dist = encdec.dec.attention.attentiongenerator.dist
        expected.update([dist.lin.W, dist.lin.b,
                         dist.lin2.W, dist.lin2.b,
                         dist.agg])
        self.assertEqual(expected, encdec.get_params())
示例#6
0
 def test_shapes(self):
     """AttGen.predict returns (batsize, seqlen) weights whose rows sum to 1."""
     batsize = 5
     seqlen = 3
     datadim = 4
     critdim = 3
     attdim = 7
     crit = np.random.random((batsize, critdim))
     data = np.random.random((batsize, seqlen, datadim))
     distance = LinearDistance(critdim, datadim, attdim)
     attgen = AttGen(distance)
     weights = attgen.predict(crit, data)
     self.assertEqual(weights.shape, (batsize, seqlen))
     row_totals = np.sum(weights, axis=1)
     all_ones = np.ones((weights.shape[0],))
     self.assertTrue(np.allclose(row_totals, all_ones))
 def test_seq(self):
     """LinearDistance over a sequence matches a manual NumPy computation.

     For each position i of the right-hand sequence, the expected score is
     ``dot((l . W1 + b1) + (r[i] . W2 + b2), agg)``, where:
     - (W1, b1) are the parameters of ``lin``;
     - (W2, b2) are the parameters of ``lin2``.
     """
     batsize = 1
     ldim = 3
     rdim = 2
     seqlen = 4
     np.random.seed(544)
     l = np.random.random((batsize, ldim))
     r = np.random.random((batsize, seqlen, rdim))
     b = LinearDistance(ldim, rdim, 5)
     pred = b.predict(l, r)
     l = l[0]
     r = r[0]
     # Fetch the parameter values once; they do not change inside the loop.
     lw = b.lin.W.value.get_value()
     lb = b.lin.b.value.get_value()
     rw = b.lin2.W.value.get_value()
     rb = b.lin2.b.value.get_value()
     agg = b.agg.value.get_value()
     # The left-hand projection is the same for every sequence position.
     x = np.dot(l, lw) + lb
     for i in range(seqlen):
         y = np.dot(r[i], rw) + rb
         z = np.dot(x + y, agg)
         self.assertTrue(np.isclose(z, pred[0][i]))
 def test_get_params(self):
     """get_params() returns lin/lin2 weights and biases plus agg, nothing else."""
     dist = LinearDistance(10, 10, 10)
     left, right = dist.lin, dist.lin2
     expected = {left.W, left.b, right.W, right.b, dist.agg}
     actual = dist.get_params()
     self.assertEqual(expected, actual)