def test_hourglass_lm_autoregressive_property(self):
    """Run the autoregressive-property check on two HourglassLM variants.

    One model is built with a single-entry hierarchy string ('2@4') and
    one with a three-entry hierarchy string ('2@3 2@6 2@3'); every other
    constructor argument is shared. Both are passed to the
    `_test_autoregressive_property` helper with the same input shape.
    """
    vocab_size = 26
    d_model = 8
    input_shape = (1, 12)

    # Constructor arguments common to both models under test.
    shared_kwargs = dict(
        vanilla_layers=(1, 1),
        d_model=d_model,
        d_ff=d_model,
        n_heads=2,
    )

    # Both models are constructed up front, before any check runs.
    model_single_stage = hourglass.HourglassLM(
        vocab_size, hierarchy='2@4', **shared_kwargs)
    model_multi_stage = hourglass.HourglassLM(
        vocab_size, hierarchy='2@3 2@6 2@3', **shared_kwargs)

    for model in (model_single_stage, model_multi_stage):
        self._test_autoregressive_property(
            model, input_shape, output_vocab_size=vocab_size)
def test_lsh_attention_in_vanilla(self):
    """Check the forward shape of HourglassLM with LSH attention layers.

    Binds two gin parameters for the pure-LSH attention classes, builds an
    HourglassLM that uses `tl.PureLSHSelfAttentionWrapper` as its
    `vanilla_attn_type` together with linear pooling/upsampling, and hands
    the model to the `_check_forward_shape` helper.
    """
    vocab_size = 7
    d_model = 16
    input_shape = (3, 12)  # (batch_size, seq_len)

    # The gin bindings must be in place before the model is constructed.
    # NOTE(review): these bindings are global and are not reset in this
    # method -- confirm that setUp/tearDown clears the gin config.
    gin_bindings = (
        ('PureLSHSelfAttentionWrapper.pure_lsh_implementation',
         tl.PureLSHSelfAttention),
        ('PureLSHSelfAttention.chunk_len', 2),
    )
    for binding_key, binding_value in gin_bindings:
        gin.bind_parameter(binding_key, binding_value)

    lsh_model = hourglass.HourglassLM(
        vocab_size,
        hierarchy='2@3',
        vanilla_layers=(1, 1),
        d_model=d_model,
        d_ff=d_model,
        n_heads=2,
        vanilla_attn_type=tl.PureLSHSelfAttentionWrapper,
        downsampling_fn=resampling.LinearPooling,
        upsampling_fn=resampling.LinearUpsampling,
    )

    self._check_forward_shape(
        lsh_model, input_shape=input_shape, output_vocab_size=vocab_size)
def test_hourglass_lm_forward_shape(self):
    """Check the forward shape of an HourglassLM with a three-entry hierarchy.

    Builds the model with hierarchy '2@3 2@6 2@3' and passes it, together
    with a (3, 24) input shape and the vocabulary size, to the
    `_check_forward_shape` helper.
    """
    vocab_size = 7
    hidden_size = 16  # Used for both d_model and d_ff.
    input_shape = (3, 24)  # (batch_size, seq_len)

    lm = hourglass.HourglassLM(
        vocab_size,
        hierarchy='2@3 2@6 2@3',
        vanilla_layers=(1, 1),
        d_model=hidden_size,
        d_ff=hidden_size,
        n_heads=2,
    )

    self._check_forward_shape(
        lm, input_shape=input_shape, output_vocab_size=vocab_size)