def test_viterbi_decode(dtype):
    """Check ``text.viterbi_decode`` against an exhaustive search over tag paths.

    Every tag sequence of the unpadded length is scored with
    ``text.crf_sequence_score``. The highest-scoring candidate must match the
    sequence and score returned by the Viterbi decoder.

    Args:
        dtype: dtype forwarded to ``np.array`` when building the unary scores
            and the transition matrix.
    """
    inputs = np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=dtype)
    transition_params = np.array([[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=dtype)
    sequence_lengths = np.array(3, dtype=np.int32)
    num_words, num_tags = inputs.shape

    # Candidates only vary over the first `sequence_lengths` positions; the
    # remaining positions are filled with tag 0 so each path has `num_words` tags.
    padding = [0] * (num_words - sequence_lengths)

    def brute_force_score(path):
        # The scorer is called with a leading batch axis of size 1, which is
        # squeezed away again to get the score of this single path.
        batched_score = text.crf_sequence_score(
            inputs=tf.expand_dims(inputs, 0),
            tag_indices=tf.expand_dims(path, 0),
            sequence_lengths=tf.expand_dims(sequence_lengths, 0),
            transition_params=tf.constant(transition_params),
        )
        return tf.squeeze(batched_score, [0])

    # Compare the dynamic program with brute force computation.
    candidates = [
        list(prefix) + padding
        for prefix in itertools.product(range(num_tags), repeat=sequence_lengths)
    ]
    scores = [brute_force_score(path) for path in candidates]

    best = np.argmax(scores)

    decoded_tags, decoded_score = text.viterbi_decode(
        inputs[:sequence_lengths], transition_params
    )

    test_utils.assert_allclose_according_to_type(decoded_score, scores[best])
    # Drop the padding tags before comparing with the decoder output.
    assert decoded_tags == candidates[best][:sequence_lengths]
def testViterbiDecode(self):
    """Check ``text.viterbi_decode`` against an exhaustive search over tag paths.

    Every tag sequence of the unpadded length is scored with
    ``text.crf_sequence_score`` and the scores are materialised through
    ``self.evaluate``. The highest-scoring candidate must match the sequence
    and score returned by the Viterbi decoder.
    """
    inputs = np.array(
        [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32
    )
    transition_params = np.array(
        [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32
    )
    sequence_lengths = np.array(3, dtype=np.int32)
    # TODO: https://github.com/PyCQA/pylint/issues/3139
    # pylint: disable=E1136
    num_tags = inputs.shape[1]
    num_words = inputs.shape[0]
    # pylint: enable=E1136

    # Candidates only vary over the first `sequence_lengths` positions; the
    # remaining positions are filled with tag 0 so each path has `num_words` tags.
    padding = [0] * (num_words - sequence_lengths)

    def brute_force_score(path):
        # The scorer is called with a leading batch axis of size 1, which is
        # squeezed away again to get the score of this single path.
        batched_score = text.crf_sequence_score(
            inputs=tf.expand_dims(inputs, 0),
            tag_indices=tf.expand_dims(path, 0),
            sequence_lengths=tf.expand_dims(sequence_lengths, 0),
            transition_params=tf.constant(transition_params),
        )
        return tf.squeeze(batched_score, [0])

    # Compare the dynamic program with brute force computation.
    candidates = [
        list(prefix) + padding
        for prefix in itertools.product(range(num_tags), repeat=sequence_lengths)
    ]
    score_tensors = [brute_force_score(path) for path in candidates]

    # Evaluate all candidate scores at once so argmax operates on concrete values.
    evaluated_scores = self.evaluate(score_tensors)
    best = np.argmax(evaluated_scores)

    decoded_tags, decoded_score = text.viterbi_decode(
        inputs[:sequence_lengths], transition_params
    )

    self.assertAllClose(decoded_score, evaluated_scores[best])
    # Drop the padding tags before comparing with the decoder output.
    self.assertEqual(decoded_tags, candidates[best][:sequence_lengths])