class TestAttentionRNNDecoder(TestCase):
    """Tests for SeqDecoder output shapes and attention-parameter registration.

    Covers two decoders built over the same GRU stack: one with a bilinear
    attention mechanism (``decwatt``) and one without (``decwoatt``).
    """

    def setUp(self):
        # Fixture dimensions — arbitrary small sizes for fast tests.
        num_classes = 10   # vocabulary size
        dec_dim = 50       # decoder inner (hidden) dimension
        enc_dim = 30       # encoder output dimension
        seq_len = 5
        batch_size = 77

        # Bilinear attention: scores decoder state against encoder states,
        # then summarizes them with a weighted sum.
        self.att = Attention(
            AttGen(BilinearDistance(dec_dim, enc_dim)),
            WeightedSumAttCon(),
        )

        # Decoder WITH attention over the encoded sequence.
        self.decwatt = SeqDecoder(
            [IdxToOneHot(num_classes),
             GRU(dim=num_classes + enc_dim, innerdim=dec_dim)],
            inconcat=True,
            attention=self.att,
            innerdim=dec_dim,
        )

        # Decoder WITHOUT attention — same layer stack otherwise.
        self.decwoatt = SeqDecoder(
            [IdxToOneHot(num_classes),
             GRU(dim=num_classes + enc_dim, innerdim=dec_dim)],
            inconcat=True,
            innerdim=dec_dim,
        )

        # Random inputs: a sequence of encoder states, a single context
        # vector, and integer token ids to decode.
        self.attdata = np.random.random((batch_size, seq_len, enc_dim)).astype("float32")
        self.data = np.random.random((batch_size, enc_dim)).astype("float32")
        self.seqdata = np.random.randint(0, num_classes, (batch_size, seq_len))
        self.predshape = (batch_size, seq_len, num_classes)

    def test_shape(self):
        """Attention decoder produces (batch, seq, vocab) predictions."""
        out = self.decwatt.predict(self.attdata, self.seqdata)
        self.assertEqual(out.shape, self.predshape)

    def test_shape_wo_att(self):
        """Attention-free decoder produces the same prediction shape."""
        out = self.decwoatt.predict(self.data, self.seqdata)
        self.assertEqual(out.shape, self.predshape)

    def test_attentiongenerator_param_in_allparams(self):
        """The bilinear distance weight is registered on the attention decoder."""
        _, outputs = self.decwatt.autobuild(self.attdata, self.seqdata)
        self.assertIn(self.att.attentiongenerator.dist.W, outputs[0].allparams)

    def test_attentiongenerator_param_not_in_params_of_dec_wo_att(self):
        """The bilinear distance weight is absent from the plain decoder."""
        _, outputs = self.decwoatt.autobuild(self.data, self.seqdata)
        self.assertNotIn(self.att.attentiongenerator.dist.W, outputs[0].allparams)