コード例 #1
0
class TestAttentionRNNDecoder(TestCase):
    """Tests for ``SeqDecoder`` built with and without an attention module.

    Both decoders share the same front end: token ids go through
    ``IdxToOneHot`` and then a ``GRU``. They differ only in whether an
    ``Attention`` instance is passed. The tests check:

    * the predicted output shape of each decoder, ``(batch, seq, vocab)``;
    * whether the attention generator's ``dist.W`` parameter appears in the
      ``allparams`` of the first output returned by ``autobuild``.
    """

    def setUp(self):
        vocab_size, inner_dim, enc_dim = 10, 50, 30
        seq_len, batch_size = 5, 77

        self.att = Attention(
            AttGen(BilinearDistance(inner_dim, enc_dim)),
            WeightedSumAttCon(),
        )

        def make_decoder(**extra):
            # GRU input width is vocab_size + enc_dim. Presumably
            # inconcat=True concatenates the encoder context onto each
            # one-hot step -- NOTE(review): confirm against SeqDecoder.
            layers = [
                IdxToOneHot(vocab_size),
                GRU(dim=vocab_size + enc_dim, innerdim=inner_dim),
            ]
            return SeqDecoder(
                layers, inconcat=True, innerdim=inner_dim, **extra
            )

        # Only the attention decoder receives an ``attention`` keyword; the
        # other one omits it entirely rather than passing None.
        self.decwatt = make_decoder(attention=self.att)
        self.decwoatt = make_decoder()

        # Context fed to the attention decoder: one enc_dim vector per
        # timestep -> (batch, seq, enc).
        self.attdata = np.random.random(
            (batch_size, seq_len, enc_dim)
        ).astype("float32")
        # Context fed to the plain decoder: a single enc_dim vector per row.
        self.data = np.random.random((batch_size, enc_dim)).astype("float32")
        # Integer token ids in [0, vocab_size).
        self.seqdata = np.random.randint(0, vocab_size, (batch_size, seq_len))
        self.predshape = (batch_size, seq_len, vocab_size)

    def _check_prediction_shape(self, decoder, context):
        """Run ``decoder.predict`` on ``context`` plus the token ids and
        assert the result has the expected ``(batch, seq, vocab)`` shape."""
        prediction = decoder.predict(context, self.seqdata)
        self.assertEqual(prediction.shape, self.predshape)

    def _built_allparams(self, decoder, context):
        """Autobuild ``decoder`` and return ``allparams`` of its first output.

        ``autobuild`` is unpacked as a 2-tuple ``(inputs, outputs)``; the
        inputs are not needed here.
        """
        _, outputs = decoder.autobuild(context, self.seqdata)
        return outputs[0].allparams

    def test_shape(self):
        self._check_prediction_shape(self.decwatt, self.attdata)

    def test_shape_wo_att(self):
        self._check_prediction_shape(self.decwoatt, self.data)

    def test_attentiongenerator_param_in_allparams(self):
        # Build first, then read W, matching the original evaluation order.
        params = self._built_allparams(self.decwatt, self.attdata)
        self.assertIn(self.att.attentiongenerator.dist.W, params)

    def test_attentiongenerator_param_not_in_params_of_dec_wo_att(self):
        params = self._built_allparams(self.decwoatt, self.data)
        self.assertNotIn(self.att.attentiongenerator.dist.W, params)