def test_serialization():
    """A layer must round-trip losslessly through get_config/from_config."""
    original = MultiHeadedSelfAttention()
    rebuilt = MultiHeadedSelfAttention.from_config(original.get_config())
    # Configs compare equal iff the rebuilt layer is equivalent to the original.
    assert rebuilt.get_config() == original.get_config()
def test_causality():
    """Causal masking: perturbing only the final timestep must leave every
    earlier output position unchanged while still altering the full output.
    """
    batch, steps, features = 56, 10, 30
    n_units = 128
    inputs = tf.random.normal(shape=(batch, steps, features))

    layer = MultiHeadedSelfAttention(
        n_heads=4, n_units=n_units, causality=True, name='mh_self_attention')
    baseline = layer(inputs)

    # Build a second input that differs from the first ONLY at the last step.
    perturbed_inputs = tf.concat(
        [inputs[:, :-1, :], tf.random.normal((batch, 1, features))],
        axis=1)
    perturbed = layer(perturbed_inputs)

    # With a causal mask, steps 0..T-2 never attend to step T-1, so their
    # outputs are computed from identical values and match exactly.
    assert (baseline[:, :-1, :].numpy() == perturbed[:, :-1, :].numpy()).all()
    # The overall tensors must still differ, since the last step changed.
    assert not (baseline.numpy() == perturbed.numpy()).all()
def test_shapes():
    """Output shapes for single-head (default units) and multi-head layers."""
    input_shape = (56, 10, 30)
    n_units = 128
    inputs = tf.random.normal(shape=input_shape)

    single_head = MultiHeadedSelfAttention(n_heads=1, name='self_attention')
    multi_head = MultiHeadedSelfAttention(
        n_heads=4, n_units=n_units, name='mh_self_attention')

    single_out = single_head(inputs)
    multi_out = multi_head(inputs)

    # A single head with default units preserves the input shape.
    assert single_out.shape == input_shape
    # Multi-head with explicit n_units projects the last axis to n_units.
    assert multi_out.shape == input_shape[:-1] + (n_units, )
def test_exceptions():
    """Calling the layer with n_units not divisible by n_heads must raise."""
    input_shape = (56, 10, 30)
    n_heads = 4
    n_units = 99  # deliberately not a multiple of n_heads
    inputs = tf.random.normal(shape=input_shape)

    layer = MultiHeadedSelfAttention(
        n_heads=n_heads, n_units=n_units, name='mh_self_attention')

    # The divisibility check fires on the first call, not at construction.
    with pytest.raises(ValueError) as excinfo:
        layer(inputs)
    assert 'n_units must be divisible by n_heads' in str(excinfo.value)
def test_masking():
    """The layer must accept mask-propagating input from a Masking layer
    and keep the input shape unchanged.
    """
    input_shape = (56, 10, 30)
    inputs = tf.random.normal(shape=input_shape)

    # Per-sample boolean condition derived from each sample's total sum;
    # selected samples are zeroed so Masking(mask_value=0.) will mask them.
    condition = tf.less(
        tf.reduce_sum(tf.reduce_sum(inputs, axis=-1), axis=-1), 0)
    masked_inputs = tf.where(condition, tf.zeros_like(inputs), inputs)

    masking_layer = tf.keras.layers.Masking(
        mask_value=0., input_shape=input_shape[1:])
    attention = MultiHeadedSelfAttention(n_heads=3, name='mh_self_attention')

    output = attention(masking_layer(masked_inputs))
    assert output.shape == input_shape