Esempio n. 1
0
 def test_create_mask_layer(self):
     """Check the mask produced by a default-constructed AttentionMaskLayer.

     A zero array of shape (1, 2, 5) is fed through the layer; the result is
     expected to have shape (2, 2) and to be True on and below the diagonal.
     """
     mask_layer = ra.AttentionMaskLayer()
     inputs = np.zeros((1, 2, 5))
     mask_layer.init(shapes.signature(inputs))

     produced = mask_layer(inputs)

     expected = [[True, False], [True, True]]
     self.assertEqual(produced.shape, (2, 2))
     np.testing.assert_equal(tl.to_list(produced), expected)
Esempio n. 2
0
 def test_create_mask_layer(self):
     """Check the mask returned as the 4th output of AttentionMaskLayer.

     Inputs come from ``_get_xs(q=2, k=2)``. The layer must return exactly
     four values; the last one is expected to have shape (1, 1, 2, 2) and to
     be True on and below the diagonal of its trailing 2x2 block.
     """
     mask_layer = ra.AttentionMaskLayer()
     inputs = _get_xs(q=2, k=2)
     mask_layer.init(shapes.signature(inputs))

     # Only the last of the four outputs is inspected here.
     _, _, _, produced = mask_layer(inputs)

     expected = [[[[True, False], [True, True]]]]
     self.assertEqual(produced.shape, (1, 1, 2, 2))
     np.testing.assert_equal(tl.to_list(produced), expected)
Esempio n. 3
0
    def test_create_mask_layer_predict(self):
        """Check how the predict-mode mask evolves over repeated calls.

        The layer is called six times with the same (1, 1, 5) zero input.
        Every call must yield a (1, 3) mask, and one more position is
        expected to become True after each pair of calls.
        NOTE(review): the step of two calls presumably follows from
        ``total_kv_pooling=2`` -- confirm against the layer implementation.
        """
        mask_layer = ra.AttentionMaskLayer(
            total_kv_pooling=2,
            n_raw_tokens_generated=1,
            max_inference_length=3,
            mode='predict',
        )
        inputs = np.zeros((1, 1, 5))
        mask_layer.init(shapes.signature(inputs))

        # Each expected mask must be observed on two consecutive calls, in
        # this order; the layer carries state between calls.
        expected_sequence = (
            [[True, False, False]],
            [[True, True, False]],
            [[True, True, True]],
        )
        calls_per_mask = 2

        for expected in expected_sequence:
            for _ in range(calls_per_mask):
                produced = mask_layer(inputs)
                self.assertEqual(produced.shape, (1, 3))
                np.testing.assert_equal(tl.to_list(produced), expected)