def test_no_segment_ids(self):
  # Passing `segment_ids=None` is expected to yield no segment matrix at all,
  # regardless of the other arguments supplied.
  self.assertIsNone(
      xlnet_base._compute_segment_matrix(
          segment_ids=None,
          memory_length=2,
          batch_size=1,
          use_cls_mask=False))
def test_basic(self):
  # One sequence of four tokens: positions 0, 1 and 3 carry segment id 1,
  # position 2 carries segment id 2. No memory is used.
  token_segments = np.array([[1, 1, 2, 1]])

  # Expected mask is True exactly where the row token and the column token
  # carry different segment ids.
  row_for_segment_1 = [False, False, True, False]
  row_for_segment_2 = [True, True, False, True]
  expected = np.array([[
      row_for_segment_1,
      row_for_segment_1,
      row_for_segment_2,
      row_for_segment_1,
  ]])

  actual = xlnet_base._compute_segment_matrix(
      segment_ids=token_segments,
      memory_length=0,
      batch_size=1,
      use_cls_mask=False)

  self.assertAllClose(actual, expected)
def test_basic_with_memory(self):
  # Same four-token sequence as the basic case, but with memory_length=1.
  token_segments = np.array([[1, 1, 2, 1]])

  # Compared with the no-memory expectation, every row gains one leading
  # column that is True; the remaining four columns are True where the row
  # and column tokens carry different segment ids.
  row_for_segment_1 = [True, False, False, True, False]
  row_for_segment_2 = [True, True, True, False, True]
  expected = np.array([[
      row_for_segment_1,
      row_for_segment_1,
      row_for_segment_2,
      row_for_segment_1,
  ]]).astype(int)

  raw_matrix = xlnet_base._compute_segment_matrix(
      segment_ids=token_segments,
      memory_length=1,
      batch_size=1,
      use_cls_mask=False)
  # Cast to an integer tensor so it is compared numerically against the
  # integer-valued expectation.
  actual = tf.cast(raw_matrix, dtype=tf.uint8)

  self.assertAllClose(actual, expected)
def dont_test_basic_with_class_mask(self):
  # TODO(allencwang) - this test should pass but illustrates the legacy issue
  # of using class mask. Enable once addressed.
  #
  # The `dont_` prefix keeps the test runner from collecting this method.
  token_segments = np.array([[1, 1, 2, 1]])

  # Expected mask with use_cls_mask=True: identical to the plain no-memory
  # expectation (True where row and column tokens differ in segment id).
  row_for_segment_1 = [False, False, True, False]
  row_for_segment_2 = [True, True, False, True]
  expected = np.array([[
      row_for_segment_1,
      row_for_segment_1,
      row_for_segment_2,
      row_for_segment_1,
  ]]).astype(int)

  raw_matrix = xlnet_base._compute_segment_matrix(
      segment_ids=token_segments,
      memory_length=0,
      batch_size=1,
      use_cls_mask=True)
  # Cast to an integer tensor so it is compared numerically against the
  # integer-valued expectation.
  actual = tf.cast(raw_matrix, dtype=tf.uint8)

  self.assertAllClose(actual, expected)