def test_casual_attention_mask_with_no_memory(self):
        """Checks the mask built for a length-3 sequence with zero memory.

        The expected result is a 3x3 lower-triangular matrix of ones.
        """
        # NOTE(review): "casual" in the method name looks like a typo for
        # "causal"; it is left untouched so the test ID stays stable.
        expected_mask = np.array([
            [1, 0, 0],
            [1, 1, 0],
            [1, 1, 1],
        ])

        actual_mask = xlnet_base._create_causal_attention_mask(
            seq_length=3, memory_length=0)

        self.assertAllClose(actual_mask, expected_mask)
    def test_causal_attention_mask_with_same_length(self):
        """Checks the mask built when `same_length=True` and memory is used.

        For a length-3 sequence with memory length 2, the expected 3x5 mask
        holds exactly three ones per row, and that band of ones moves one
        column to the right on each successive row.
        """
        expected_mask = np.array([
            [1, 1, 1, 0, 0],
            [0, 1, 1, 1, 0],
            [0, 0, 1, 1, 1],
        ])

        actual_mask = xlnet_base._create_causal_attention_mask(
            seq_length=3, memory_length=2, same_length=True)

        self.assertAllClose(actual_mask, expected_mask)