Ejemplo n.º 1
0
def test_crf_multi_tag_sequence_score(dtype):
    """Check `text.crf_multitag_sequence_score` against a brute-force reference.

    Each time step of a tag bitmap lists the tags allowed at that step. The
    reference enumerates every tag sequence permitted by the bitmap over the
    first `sequence_lengths` steps. It scores each one with
    `calculate_sequence_score` and takes the log-sum-exp of those scores. The
    op's result must match that value.

    Args:
        dtype: dtype used for the transition matrix and the unary inputs
            (presumably a floating-point dtype supplied by a pytest
            parametrization; confirm against the surrounding file).
    """
    transition_params = np.array([[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=dtype)
    # Test both the length-1 and regular cases.
    sequence_lengths_list = [
        np.array(3, dtype=np.int32),
        np.array(1, dtype=np.int32),
    ]
    inputs_list = [
        # Four rows but a length of 3: the last row is padding that the
        # reference computation below never looks at.
        np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=dtype),
        np.array([[4, 5, -3]], dtype=dtype),
    ]
    # `np.bool` was only a deprecated alias of the builtin `bool` and was
    # removed in NumPy 1.24, so use the builtin directly.
    tag_bitmap_list = [
        np.array(
            [
                [True, True, False],
                [True, False, True],
                [False, True, True],
                [True, False, True],
            ],
            dtype=bool,
        ),
        np.array([[True, True, False]], dtype=bool),
    ]
    for sequence_lengths, inputs, tag_bitmap in zip(
        sequence_lengths_list, inputs_list, tag_bitmap_list
    ):
        # The op takes batched arguments: add a batch axis of size 1 here and
        # squeeze it away from the result.
        sequence_score = text.crf_multitag_sequence_score(
            inputs=tf.expand_dims(inputs, 0),
            tag_bitmap=tf.expand_dims(tag_bitmap, 0),
            sequence_lengths=tf.expand_dims(sequence_lengths, 0),
            transition_params=tf.constant(transition_params),
        )
        sequence_score = tf.squeeze(sequence_score, [0])
        # Allowed tag indices at each unpadded time step.
        all_indices_list = [
            single_index_bitmap.nonzero()[0]
            for single_index_bitmap in tag_bitmap[:sequence_lengths]
        ]
        # The Cartesian product enumerates every allowed tag sequence.
        expected_sequence_scores = [
            calculate_sequence_score(
                inputs, transition_params, indices, sequence_lengths
            )
            for indices in itertools.product(*all_indices_list)
        ]
        expected_log_sum_exp_sequence_scores = np.logaddexp.reduce(
            expected_sequence_scores
        )
        test_utils.assert_allclose_according_to_type(
            sequence_score, expected_log_sum_exp_sequence_scores
        )
Ejemplo n.º 2
0
 def testCrfMultiTagSequenceScore(self):
     """Compare `text.crf_multitag_sequence_score` with a brute-force reference.

     For every tag sequence permitted by the tag bitmap over the first
     `sequence_lengths` steps, the reference score comes from
     `self.calculateSequenceScore`. The expected value is the log-sum-exp of
     those scores, and the op's evaluated result must match it.
     """
     transition_params = np.array([[-3, 5, -2], [3, 4, 1], [1, 2, 1]],
                                  dtype=np.float32)
     # Test both the length-1 and regular cases.
     sequence_lengths_list = [
         np.array(3, dtype=np.int32),
         np.array(1, dtype=np.int32)
     ]
     inputs_list = [
         # Four rows but a length of 3: the last row is padding that the
         # reference computation below never looks at.
         np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]],
                  dtype=np.float32),
         np.array([[4, 5, -3]], dtype=np.float32),
     ]
     # `np.bool` was only a deprecated alias of the builtin `bool` and was
     # removed in NumPy 1.24, so use the builtin directly.
     tag_bitmap_list = [
         np.array([[True, True, False], [True, False, True],
                   [False, True, True], [True, False, True]],
                  dtype=bool),
         np.array([[True, True, False]], dtype=bool)
     ]
     for sequence_lengths, inputs, tag_bitmap in zip(
             sequence_lengths_list, inputs_list, tag_bitmap_list):
         # The op takes batched arguments: add a batch axis of size 1 here
         # and squeeze it away from the result.
         sequence_score = text.crf_multitag_sequence_score(
             inputs=tf.expand_dims(inputs, 0),
             tag_bitmap=tf.expand_dims(tag_bitmap, 0),
             sequence_lengths=tf.expand_dims(sequence_lengths, 0),
             transition_params=tf.constant(transition_params))
         sequence_score = tf.squeeze(sequence_score, [0])
         tf_sum_sequence_score = self.evaluate(sequence_score)
         # Allowed tag indices at each unpadded time step.
         all_indices_list = [
             single_index_bitmap.nonzero()[0]
             for single_index_bitmap in tag_bitmap[:sequence_lengths]
         ]
         # The Cartesian product enumerates every allowed tag sequence.
         expected_sequence_scores = [
             self.calculateSequenceScore(inputs, transition_params, indices,
                                         sequence_lengths)
             for indices in itertools.product(*all_indices_list)
         ]
         expected_log_sum_exp_sequence_scores = np.logaddexp.reduce(
             expected_sequence_scores)
         self.assertAllClose(tf_sum_sequence_score,
                             expected_log_sum_exp_sequence_scores)