예제 #1
0
    def test_hourglass_lm_autoregressive_property(self):
        """Run the autoregressive-property check on two HourglassLM variants.

        Builds a single-stage model (hierarchy '2@4') and a multi-stage model
        (hierarchy '2@3 2@6 2@3') that are otherwise identically configured.
        It then runs ``self._test_autoregressive_property`` on each with an
        input of shape (1, 12). That helper is defined elsewhere in the test
        class.
        """
        d_model = 8
        vocab_size = 26
        input_shape = (1, 12)

        def build(hierarchy):
            # Only the hierarchy differs between the two models under test.
            return hourglass.HourglassLM(
                vocab_size,
                hierarchy=hierarchy,
                vanilla_layers=(1, 1),
                d_model=d_model,
                d_ff=d_model,
                n_heads=2,
            )

        # Construct every model first, then check them, matching the original
        # order: single-stage before multi-stage.
        models = [build(spec) for spec in ('2@4', '2@3 2@6 2@3')]
        for model in models:
            self._test_autoregressive_property(
                model, input_shape, output_vocab_size=vocab_size)
예제 #2
0
    def test_lsh_attention_in_vanilla(self):
        """Check forward output shape of HourglassLM using LSH vanilla attention.

        Gin is configured so that ``PureLSHSelfAttentionWrapper`` uses
        ``tl.PureLSHSelfAttention`` with ``chunk_len=2``. The model uses that
        wrapper as its vanilla attention type, with linear pooling and
        upsampling for resampling. The forward shape is then checked on a
        (3, 12) input.
        """
        d_model, vocab_size = 16, 7

        # NOTE(review): gin.bind_parameter mutates global gin config. Presumably
        # the enclosing test class resets gin state between tests (not visible
        # here) -- confirm, otherwise these bindings leak into later tests.
        gin_bindings = (
            ('PureLSHSelfAttentionWrapper.pure_lsh_implementation',
             tl.PureLSHSelfAttention),
            ('PureLSHSelfAttention.chunk_len', 2),
        )
        for binding_key, binding_value in gin_bindings:
            gin.bind_parameter(binding_key, binding_value)

        model = hourglass.HourglassLM(
            vocab_size,
            hierarchy='2@3',
            vanilla_layers=(1, 1),
            d_model=d_model,
            d_ff=d_model,
            n_heads=2,
            vanilla_attn_type=tl.PureLSHSelfAttentionWrapper,
            downsampling_fn=resampling.LinearPooling,
            upsampling_fn=resampling.LinearUpsampling,
        )

        shape = (3, 12)  # (batch_size, seq_len)
        self._check_forward_shape(
            model, input_shape=shape, output_vocab_size=vocab_size)
예제 #3
0
    def test_hourglass_lm_forward_shape(self):
        """Check forward output shape of a multi-stage HourglassLM.

        The model uses hierarchy '2@3 2@6 2@3'. The check runs through
        ``self._check_forward_shape`` (defined elsewhere in the test class) on
        a (3, 24) input.
        """
        vocab_size = 7
        d_model = 16
        model_kwargs = {
            'hierarchy': '2@3 2@6 2@3',
            'vanilla_layers': (1, 1),
            'd_model': d_model,
            'd_ff': d_model,
            'n_heads': 2,
        }
        model = hourglass.HourglassLM(vocab_size, **model_kwargs)

        batch_size = 3
        seq_len = 24
        self._check_forward_shape(
            model,
            input_shape=(batch_size, seq_len),
            output_vocab_size=vocab_size,
        )