Beispiel #1
0
    def test_various_calls(self):
        """Checks MultiplicativeConvCausalAttention across a grid of options.

        Builds every combination of ``share_qk``, ``output_layer_type`` and
        ``v_concat_type`` values listed below. For each combination it
        initializes the layer on a fixed (1, 3, 4) input and asserts that the
        output shape is (1, 3, 4).
        """
        # Cartesian product of the option values under test. The comprehension
        # yields the combinations in the same order the nested loops did.
        list_kwargs = [
            {
                'share_qk': share_qk,
                'output_layer_type': output,
                'v_concat_type': concat,
            }
            for share_qk in [True, False]
            for output in ['none', 'mult', 'conv', 'multconv']
            for concat in ['original', 'fixed', 'none']
        ]
        # The input does not depend on the configuration, so build it once
        # instead of once per iteration.
        x = np.array([[
            [2, 5, 3, 4],
            [0, 1, 2, 3],
            [0, 1, 2, 3],
        ]])
        for kwargs in list_kwargs:
            # subTest labels a failure with the offending configuration and
            # keeps checking the remaining ones rather than stopping early.
            with self.subTest(**kwargs):
                layer = sparsity.MultiplicativeConvCausalAttention(
                    d_feature=4, n_heads=2, sparsity=2, **kwargs)
                _, _ = layer.init(shapes.signature(x))

                y = layer(x)
                self.assertEqual(y.shape, (1, 3, 4))
Beispiel #2
0
  def test_simple_call(self):
    """Runs MultiplicativeConvCausalAttention once with default options.

    Constructs the layer with only d_feature, n_heads and sparsity set,
    initializes it on a (1, 3, 4) input, and checks that calling it yields
    an output of shape (1, 3, 4).
    """
    inputs = np.array([[[2, 5, 3, 4],
                        [0, 1, 2, 3],
                        [0, 1, 2, 3]]])
    attn = sparsity.MultiplicativeConvCausalAttention(
        d_feature=4,
        n_heads=2,
        sparsity=2,
    )
    # Initialize from the input's signature; the two returned values are
    # unpacked but not used afterwards.
    _, _ = attn.init(shapes.signature(inputs))

    outputs = attn(inputs)
    self.assertEqual(outputs.shape, (1, 3, 4))