class DummyAttentionGeneratorConsumerTest(TestCase):
    """Shape and parameter checks for an attention generator and consumer.

    The generator class and the set of parameters it is expected to
    expose are obtained through the ``getattgenc`` and
    ``getattgenparams`` hook methods.
    """

    def setUp(self):
        """Build the generator, the attention block and random inputs."""
        n_batch, n_steps = 33, 11
        crit_width, data_width, hidden_width = 30, 20, 25

        # Shapes the predictions are compared against in the shape tests.
        self.attgenshape = (n_batch, n_steps)
        self.attconshape = (n_batch, data_width)

        # The generator is constructed with the summed criterion and data
        # widths as its ``indim``.
        self.attgenc = self.getattgenc()
        self.attgen = self.attgenc(
            indim=crit_width + data_width,
            innerdim=hidden_width,
        )
        # Must run after ``self.attgen`` exists: the hook reads it.
        self.attgenparams = self.getattgenparams()

        self.attcon = WeightedSumAttCon()
        self.att = Attention(self.attgen, self.attcon)

        # Random float32 inputs; criterion is drawn first, then data.
        crit_shape = (n_batch, crit_width)
        data_shape = (n_batch, n_steps, data_width)
        self.criterion_val = np.random.random(crit_shape).astype("float32")
        self.data_val = np.random.random(data_shape).astype("float32")

    def getattgenc(self):
        """Return the attention generator class to instantiate in setUp."""
        return LinearSumAttentionGenerator

    def getattgenparams(self):
        """Return the set of parameters expected from the generator."""
        return {self.attgen.W}

    def test_generator_shape(self):
        # The generator's prediction must have shape (batsize, seqlen).
        generated = self.attgen.predict(self.criterion_val, self.data_val)
        self.assertEqual(generated.shape, self.attgenshape)

    def test_generator_param_prop(self):
        # ``predict`` is run first; ``output`` is read only afterwards.
        self.attgen.predict(self.criterion_val, self.data_val)
        found = self.attgen.output.allparams
        self.assertSetEqual(found, self.attgenparams)

    def test_consumer_shape(self):
        # The full attention block must yield shape (batsize, datadim).
        consumed = self.att.predict(self.criterion_val, self.data_val)
        self.assertEqual(consumed.shape, self.attconshape)

    def test_consumer_param_prop(self):
        # The attention block's output is expected to report exactly the
        # generator's parameters.
        self.att.predict(self.criterion_val, self.data_val)
        found = self.att.output.allparams
        self.assertSetEqual(found, self.attgenparams)
# Example #2
0
class DummyAttentionGeneratorConsumerTest(TestCase):
    """Checks output shapes and reported parameters of an attention setup.

    ``getattgenc`` supplies the generator class and ``getattgenparams``
    supplies the parameter set the tests compare against.
    """

    def setUp(self):
        """Create the generator, consumer, attention block and test data."""
        dims = {"criterion": 30, "data": 20, "inner": 25}
        batch, steps = 33, 11

        # Expected prediction shapes for the generator and the full block.
        self.attgenshape = (batch, steps)
        self.attconshape = (batch, dims["data"])

        # Generator input width is criterion width plus data width.
        generator_indim = dims["criterion"] + dims["data"]
        self.attgenc = self.getattgenc()
        self.attgen = self.attgenc(indim=generator_indim,
                                   innerdim=dims["inner"])
        # Evaluated once the generator exists, since the hook reads it.
        self.attgenparams = self.getattgenparams()

        self.attcon = WeightedSumAttCon()
        self.att = Attention(self.attgen, self.attcon)

        # Random inputs cast to float32 (criterion sampled before data).
        criterion = np.random.random((batch, dims["criterion"]))
        self.criterion_val = criterion.astype("float32")
        data = np.random.random((batch, steps, dims["data"]))
        self.data_val = data.astype("float32")

    def getattgenc(self):
        """Return the generator class that setUp instantiates."""
        return LinearSumAttentionGenerator

    def getattgenparams(self):
        """Return the expected parameter set: the generator's ``W`` only."""
        return {self.attgen.W}

    def test_generator_shape(self):
        # Generator prediction shape must equal (batsize, seqlen).
        result = self.attgen.predict(self.criterion_val, self.data_val)
        self.assertEqual(result.shape, self.attgenshape)

    def test_generator_param_prop(self):
        # Predict first, then inspect the parameters on ``output``.
        self.attgen.predict(self.criterion_val, self.data_val)
        reported = self.attgen.output.allparams
        self.assertSetEqual(reported, self.attgenparams)

    def test_consumer_shape(self):
        # Attention block prediction shape must equal (batsize, datadim).
        result = self.att.predict(self.criterion_val, self.data_val)
        self.assertEqual(result.shape, self.attconshape)

    def test_consumer_param_prop(self):
        # The block's output must report the same set as the generator.
        self.att.predict(self.criterion_val, self.data_val)
        reported = self.att.output.allparams
        self.assertSetEqual(reported, self.attgenparams)
# Example #3
0
class DummyAttentionGeneratorConsumerTest(TestCase):
    """Shape and parameter checks for a dot-distance attention setup.

    The generator instance comes from ``getattgenc`` and the parameter
    set the tests compare against comes from ``getattgenparams``.
    """

    def setUp(self):
        """Create the generator, attention block and random float32 inputs."""
        n_batch, n_steps = 33, 11
        crit_width, data_width, att_width = 20, 20, 30

        # Shapes the predictions are compared against in the shape tests.
        self.attgenshape = (n_batch, n_steps)
        self.attconshape = (n_batch, data_width)

        self.attgen = self.getattgenc(
            critdim=crit_width,
            datadim=data_width,
            attdim=att_width,
        )
        self.attgenparams = self.getattgenparams()

        self.attcon = WeightedSumAttCon()
        self.att = Attention(self.attgen, self.attcon)

        # Criterion is sampled before data; both are cast to float32.
        crit_shape = (n_batch, crit_width)
        data_shape = (n_batch, n_steps, data_width)
        self.criterion_val = np.random.random(crit_shape).astype("float32")
        self.data_val = np.random.random(data_shape).astype("float32")

    def getattgenc(self, critdim=None, datadim=None, attdim=None):
        """Return the attention generator instance under test.

        The dimension arguments are accepted but not used here: the
        generator is always built from a ``DotDistance``.
        """
        distance = DotDistance()
        return AttGen(distance)

    def getattgenparams(self):
        """Return the expected parameter set, which is empty here."""
        return set()

    def test_generator_shape(self):
        # Generator prediction must have shape (batsize, seqlen).
        generated = self.attgen.predict(self.criterion_val, self.data_val)
        self.assertEqual(generated.shape, self.attgenshape)

    def test_generator_param_prop(self):
        # ``autobuild`` returns a pair; only the first entry of its second
        # element is inspected.
        _unused, outputs = self.attgen.autobuild(
            self.criterion_val, self.data_val)
        found = outputs[0].allparams
        self.assertSetEqual(found, self.attgenparams)

    def test_consumer_shape(self):
        # Attention block prediction must have shape (batsize, datadim).
        consumed = self.att.predict(self.criterion_val, self.data_val)
        self.assertEqual(consumed.shape, self.attconshape)

    def test_consumer_param_prop(self):
        # Same check as for the generator, on the whole attention block.
        _unused, outputs = self.att.autobuild(
            self.criterion_val, self.data_val)
        found = outputs[0].allparams
        self.assertSetEqual(found, self.attgenparams)