def test_create_mask_layer(self):
    """Check the mask built by AttentionMaskLayer for a (1, 2, 5) input.

    The layer is initialized from the input's signature and called once.
    The result must be a (2, 2) lower-triangular boolean mask.
    """
    # NOTE(review): another method in this file chunk is also named
    # `test_create_mask_layer`. If both are in the same class, the later
    # definition replaces this one and this test never runs. Confirm which
    # version is intended, then rename or delete the other.
    expected_mask = [[True, False], [True, True]]

    inputs = np.zeros((1, 2, 5))
    mask_layer = ra.AttentionMaskLayer()
    mask_layer.init(shapes.signature(inputs))

    mask = mask_layer(inputs)

    self.assertEqual(mask.shape, (2, 2))
    np.testing.assert_equal(tl.to_list(mask), expected_mask)
def test_create_mask_layer(self):
    """Check the mask produced from `_get_xs(q=2, k=2)` inputs.

    The layer's output is unpacked into exactly four values, and only the
    last one (the mask) is inspected. It must have shape (1, 1, 2, 2) and
    hold a lower-triangular boolean pattern.
    """
    mask_layer = ra.AttentionMaskLayer()
    inputs = _get_xs(q=2, k=2)
    mask_layer.init(shapes.signature(inputs))

    outputs = mask_layer(inputs)
    # Exactly four outputs are expected; the first three are ignored here.
    _, _, _, mask = outputs

    expected_mask = [[[[True, False], [True, True]]]]
    self.assertEqual(mask.shape, (1, 1, 2, 2))
    np.testing.assert_equal(tl.to_list(mask), expected_mask)
def test_create_mask_layer_predict(self):
    """Check how AttentionMaskLayer's mask evolves across calls in 'predict' mode.

    The layer is built with total_kv_pooling=2, n_raw_tokens_generated=1 and
    max_inference_length=3, then called six times on the same (1, 1, 5)
    input. Every call must return a (1, 3) mask. The run of leading True
    values is expected to grow by one after each pair of calls.
    """
    mask_layer = ra.AttentionMaskLayer(
        total_kv_pooling=2,
        n_raw_tokens_generated=1,
        max_inference_length=3,
        mode='predict',
    )
    inputs = np.zeros((1, 1, 5))
    mask_layer.init(shapes.signature(inputs))

    # Identical inputs are expected to yield different masks over time, so
    # the order and number of calls matter. Each expected mask is checked
    # on both calls of its pair.
    expected_masks = (
        [[True, False, False]],
        [[True, True, False]],
        [[True, True, True]],
    )
    for expected in expected_masks:
        for _ in range(2):
            mask = mask_layer(inputs)
            self.assertEqual(mask.shape, (1, 3))
            np.testing.assert_equal(tl.to_list(mask), expected)