Example #1
0
def test_viterbi_decode(dtype):
    """Check ``text.viterbi_decode`` against an exhaustive search.

    Every tag path of length ``sequence_lengths`` is padded with tag 0 up to
    ``num_words`` and scored with ``text.crf_sequence_score``.  The
    best-scoring path, truncated back to ``sequence_lengths``, and its score
    must match what ``text.viterbi_decode`` returns for the unpadded inputs.

    Args:
        dtype: numpy dtype used for the unary potentials and the transition
            matrix (presumably supplied by a pytest parametrization outside
            this view — confirm).
    """
    inputs = np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=dtype)
    transition_params = np.array([[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=dtype)
    sequence_lengths = np.array(3, dtype=np.int32)
    num_words, num_tags = inputs.shape

    # Positions past the real sequence length are filled with tag 0.
    padding = [0] * (num_words - sequence_lengths)

    # These tensors are identical for every candidate, so build them once.
    batched_inputs = tf.expand_dims(inputs, 0)
    batched_lengths = tf.expand_dims(sequence_lengths, 0)
    transitions = tf.constant(transition_params)

    candidates = []
    scores = []
    # Brute force: enumerate and score every possible tag path.
    for prefix in itertools.product(range(num_tags), repeat=sequence_lengths):
        candidate = list(prefix) + padding
        candidates.append(candidate)
        batched_score = text.crf_sequence_score(
            inputs=batched_inputs,
            tag_indices=tf.expand_dims(candidate, 0),
            sequence_lengths=batched_lengths,
            transition_params=transitions,
        )
        scores.append(tf.squeeze(batched_score, [0]))

    best = np.argmax(scores)

    decoded_path, decoded_score = text.viterbi_decode(
        inputs[:sequence_lengths], transition_params
    )

    test_utils.assert_allclose_according_to_type(decoded_score, scores[best])
    assert decoded_path == candidates[best][:sequence_lengths]
Example #2
0
    def testViterbiDecode(self):
        """Viterbi decoding must agree with an exhaustive search over tag paths.

        Each tag path of length ``sequence_lengths`` is padded with tag 0 up to
        ``num_words`` and scored via ``text.crf_sequence_score``.  The best path
        (truncated back to ``sequence_lengths``) and its score must equal what
        ``text.viterbi_decode`` returns for the unpadded inputs.
        """
        inputs = np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]],
                          dtype=np.float32)
        transition_params = np.array([[-3, 5, -2], [3, 4, 1], [1, 2, 1]],
                                     dtype=np.float32)
        sequence_lengths = np.array(3, dtype=np.int32)
        # TODO: https://github.com/PyCQA/pylint/issues/3139
        # pylint: disable=E1136
        num_words, num_tags = inputs.shape[0], inputs.shape[1]
        # pylint: enable=E1136

        # Positions past the real sequence length are filled with tag 0.
        padding = [0] * (num_words - sequence_lengths)

        # These tensors are identical for every candidate, so build them once.
        batched_inputs = tf.expand_dims(inputs, 0)
        batched_lengths = tf.expand_dims(sequence_lengths, 0)
        transitions = tf.constant(transition_params)

        candidates = []
        score_tensors = []
        # Brute force: enumerate and score every possible tag path.
        for prefix in itertools.product(range(num_tags),
                                        repeat=sequence_lengths):
            candidate = list(prefix) + padding
            candidates.append(candidate)
            batched_score = text.crf_sequence_score(
                inputs=batched_inputs,
                tag_indices=tf.expand_dims(candidate, 0),
                sequence_lengths=batched_lengths,
                transition_params=transitions,
            )
            score_tensors.append(tf.squeeze(batched_score, [0]))

        # Turn the score tensors into concrete values before taking argmax.
        scores = self.evaluate(score_tensors)
        best = np.argmax(scores)

        decoded_path, decoded_score = text.viterbi_decode(
            inputs[:sequence_lengths], transition_params)

        self.assertAllClose(decoded_score, scores[best])
        self.assertEqual(decoded_path, candidates[best][:sequence_lengths])