def test_no_segment_ids(self):
   """Passing segment_ids=None must yield no segment matrix at all."""
   result = xlnet_base._compute_segment_matrix(
       segment_ids=None, memory_length=2, batch_size=1, use_cls_mask=False)
   self.assertIsNone(result)
 def test_basic(self):
     """Segment matrix for one sequence with no memory and no class mask.

     For segment_ids [1, 1, 2, 1] the expected entry [0, i, j] is True exactly
     when token i and token j carry different segment ids.
     """
     batch_size = 1
     memory_length = 0
     segment_ids = np.array([[1, 1, 2, 1]])
     expected_segment_matrix = np.array([[[False, False, True, False],
                                          [False, False, True, False],
                                          [True, True, False, True],
                                          [False, False, True, False]]])
     segment_matrix = xlnet_base._compute_segment_matrix(
         segment_ids=segment_ids,
         memory_length=memory_length,
         batch_size=batch_size,
         use_cls_mask=False)
     # The expected values are exact booleans, so compare for equality rather
     # than within a floating-point tolerance.
     self.assertAllEqual(segment_matrix, expected_segment_matrix)
 def test_basic_with_memory(self):
     """Segment matrix for one sequence with a single memory position.

     With memory_length=1 the expected matrix gains one extra leading column
     (the memory position), which is True in every row. The remaining columns
     are True exactly where the two tokens' segment ids differ.
     """
     batch_size = 1
     memory_length = 1
     segment_ids = np.array([[1, 1, 2, 1]])
     expected_segment_matrix = np.array([[[True, False, False, True, False],
                                          [True, False, False, True, False],
                                          [True, True, True, False, True],
                                          [True, False, False, True,
                                           False]]]).astype(int)
     segment_matrix = xlnet_base._compute_segment_matrix(
         segment_ids=segment_ids,
         memory_length=memory_length,
         batch_size=batch_size,
         use_cls_mask=False)
     # Cast to an integer dtype to line up with the int-typed expectation; the
     # values are exact 0/1, so compare for equality rather than closeness.
     self.assertAllEqual(tf.cast(segment_matrix, dtype=tf.uint8),
                         expected_segment_matrix)
 def dont_test_basic_with_class_mask(self):
     """Segment matrix with use_cls_mask=True (currently disabled).

     The method name intentionally lacks the `test` prefix, so the standard
     unittest/absltest loaders do not collect or run it.
     """
     # TODO(allencwang) - this test should pass but illustrates the legacy issue
     # of using class mask. Enable once addressed.
     batch_size = 1
     memory_length = 0
     segment_ids = np.array([[1, 1, 2, 1]])
     expected_segment_matrix = np.array([[[False, False, True, False],
                                          [False, False, True, False],
                                          [True, True, False, True],
                                          [False, False, True,
                                           False]]]).astype(int)
     segment_matrix = xlnet_base._compute_segment_matrix(
         segment_ids=segment_ids,
         memory_length=memory_length,
         batch_size=batch_size,
         use_cls_mask=True)
     # Exact 0/1 values: compare for equality rather than within a tolerance.
     self.assertAllEqual(tf.cast(segment_matrix, dtype=tf.uint8),
                         expected_segment_matrix)