def test_time_bin_causal_attention_n_bins(self):
  qkv_shape = ShapeDtype((3, 57, 8))
  input_signature = (qkv_shape, qkv_shape, qkv_shape)
  layer = efficient_attention.TimeBinCausalAttention(
      n_bins=4, dropout=0.1, mode='train')
  final_shape = base.check_shape_agreement(layer, input_signature)
  self.assertEqual((3, 57, 8), final_shape)
def test_time_bin_and_dot_product_causal_attention_are_consistent(self):
  dot_product_layer = attention.DotProductCausalAttention(
      dropout=0.0, mode='train')
  time_bin_layer = efficient_attention.TimeBinCausalAttention(
      bin_length=4, dropout=0.0, mode='train')  # Exactly 2 bins.
  input_shape = (3, 8, 8)
  inputs = [onp.random.uniform(size=input_shape) for _ in range(3)]
  dot_product_output = dot_product_layer(inputs)
  time_bin_output = time_bin_layer(inputs)
  onp.testing.assert_array_almost_equal(dot_product_output, time_bin_output)
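
# A minimal sketch, not part of the original tests: it reuses only the
# constructor arguments and call pattern seen above, and assumes that with
# dropout disabled, configuring TimeBinCausalAttention via bin_length=4 or
# via n_bins=2 describes the same binning for a length-8 sequence and should
# therefore produce the same output.
def test_time_bin_causal_attention_bin_length_matches_n_bins(self):
  bin_length_layer = efficient_attention.TimeBinCausalAttention(
      bin_length=4, dropout=0.0, mode='train')
  n_bins_layer = efficient_attention.TimeBinCausalAttention(
      n_bins=2, dropout=0.0, mode='train')
  input_shape = (3, 8, 8)  # Sequence length 8 -> two bins of length 4.
  inputs = [onp.random.uniform(size=input_shape) for _ in range(3)]
  onp.testing.assert_array_almost_equal(
      bin_length_layer(inputs), n_bins_layer(inputs))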