def test_time_bin_causal_attention_n_bins(self):
  # Forward shape check for TimeBinCausalAttention parameterized by n_bins.
  qkv_shape = ShapeDtype((3, 57, 8))
  input_signature = (qkv_shape, qkv_shape, qkv_shape)
  layer = attention.TimeBinCausalAttention(
      n_bins=4, dropout=0.1, mode='train')
  final_shape = base.check_shape_agreement(layer, input_signature)
  self.assertEqual((3, 57, 8), final_shape)
def test_time_bin_causal_attention_bin_length(self):
  # Same forward shape check, parameterized by bin_length instead of n_bins.
  qkv_shape = ShapeDtype((3, 57, 8))
  input_signature = (qkv_shape, qkv_shape, qkv_shape)
  layer = attention.TimeBinCausalAttention(
      bin_length=16, dropout=0.1, mode='train')
  final_shape = base.check_shape_agreement(layer, input_signature)
  self.assertEqual((3, 57, 8), final_shape)
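# The two examples above configure the same layer either by n_bins or by
# bin_length. The helper below is a hypothetical sketch of how those two
# parameters could relate for the shapes used here; the pad-to-a-multiple-of-
# bin_length behavior is an assumption, not taken from the
# TimeBinCausalAttention source.
def _bins_for_length(seq_len=57, bin_length=16):
  n_bins = -(-seq_len // bin_length)   # ceil(57 / 16) = 4 bins
  padded_len = n_bins * bin_length     # 64: length 57 padded up to a bin multiple
  return n_bins, padded_len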
def test_time_bin_and_dot_product_causal_attention_are_consistent(self):
  # With dropout disabled, binned causal attention should match full
  # dot-product causal attention on a sequence that splits evenly into bins.
  dot_product_layer = attention.DotProductCausalAttention(
      dropout=0.0, mode='train')
  time_bin_layer = attention.TimeBinCausalAttention(
      bin_length=4, dropout=0.0, mode='train')

  # Exactly 2 bins: sequence length 8 with bin_length=4.
  input_shape = (3, 8, 8)
  inputs = [onp.random.uniform(size=input_shape) for _ in range(3)]

  dot_product_output = dot_product_layer(inputs)
  time_bin_output = time_bin_layer(inputs)
  onp.testing.assert_array_almost_equal(dot_product_output, time_bin_output)
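# The methods above reference self, but no enclosing class appears in this
# listing. Below is a minimal sketch of a scaffold for running them, assuming
# the Trax module layout (trax.layers.attention, trax.layers.base,
# trax.shapes.ShapeDtype) and absltest; the class name is hypothetical and
# module paths may differ between Trax versions. The first example above is
# reproduced so the scaffold runs on its own; the other methods would be
# pasted in the same way.
import numpy as onp  # used by the consistency test above

from absl.testing import absltest

from trax.layers import attention
from trax.layers import base
from trax.shapes import ShapeDtype


class TimeBinCausalAttentionTest(absltest.TestCase):

  def test_time_bin_causal_attention_n_bins(self):
    qkv_shape = ShapeDtype((3, 57, 8))
    input_signature = (qkv_shape, qkv_shape, qkv_shape)
    layer = attention.TimeBinCausalAttention(
        n_bins=4, dropout=0.1, mode='train')
    final_shape = base.check_shape_agreement(layer, input_signature)
    self.assertEqual((3, 57, 8), final_shape)


if __name__ == '__main__':
  absltest.main()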