示例#1
0
def test_crf_log_norm_zero_seq_length(dtype):
    """Check `crf_log_norm` when every entry of `sequence_lengths` is zero.

    All-ones scores and transitions are fed in together with zero lengths;
    the resulting log-normalizer is expected to be zero for each batch item.
    """
    batch_size, max_seq_len, num_tags = 2, 10, 5

    unary_scores = tf.constant(
        np.ones([batch_size, max_seq_len, num_tags], dtype=dtype))
    transitions = tf.constant(np.ones([num_tags, num_tags], dtype=dtype))
    zero_lengths = tf.constant(np.zeros([batch_size], dtype=np.int32))

    actual = text.crf_log_norm(unary_scores, zero_lengths, transitions)
    test_utils.assert_allclose_according_to_type(
        actual, np.zeros([batch_size], dtype=dtype))
示例#2
0
 def testCrfLogNormZeroSeqLength(self):
     """Verify `crf_log_norm` evaluates to zeros when all entries of
     `sequence_lengths` are zero."""
     score_dtype = np.float32

     unary_scores = tf.constant(np.ones([2, 10, 5], dtype=score_dtype))
     transitions = tf.constant(np.ones([5, 5], dtype=score_dtype))
     zero_lengths = tf.constant(np.zeros([2], dtype=np.int32))

     # Evaluate the op result before comparing against the zero vector.
     actual = self.evaluate(
         text.crf_log_norm(unary_scores, zero_lengths, transitions))
     self.assertAllClose(actual, np.zeros([2], dtype=score_dtype))
示例#3
0
    def testCrfLogNorm(self):
        """Check `crf_log_norm` against a brute-force computation.

        For each test case, every possible tag sequence of the given length
        is scored with `crf_sequence_score`; the log-sum-exp of those scores
        must match the value returned by `crf_log_norm`.
        """
        transition_params = np.array([[-3, 5, -2], [3, 4, 1], [1, 2, 1]],
                                     dtype=np.float32)
        # Test both the length-1 and regular cases.
        sequence_lengths_list = [
            np.array(3, dtype=np.int32),
            np.array(1, dtype=np.int64),
        ]
        inputs_list = [
            np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]],
                     dtype=np.float32),
            np.array([[3, -1, 3]], dtype=np.float32),
        ]

        for sequence_lengths, inputs in zip(sequence_lengths_list,
                                            inputs_list):
            num_words = inputs.shape[0]
            num_tags = inputs.shape[1]

            # Use a plain Python int for `itertools.product` and for list
            # repetition; the NumPy value (with its dtype) is still what
            # gets passed to the TF ops below.
            seq_len = int(sequence_lengths)
            # Positions past the sequence length are filled with tag 0.
            padding = [0] * (num_words - seq_len)

            # These tensors do not depend on the candidate tag sequence, so
            # build them once per test case.
            batched_inputs = tf.expand_dims(inputs, 0)
            batched_lengths = tf.expand_dims(sequence_lengths, 0)
            transitions = tf.constant(transition_params)

            # Compare the dynamic program with brute force computation.
            all_sequence_scores = [
                text.crf_sequence_score(
                    inputs=batched_inputs,
                    tag_indices=tf.expand_dims(
                        list(candidate_tags) + padding, 0),
                    sequence_lengths=batched_lengths,
                    transition_params=transitions,
                )
                for candidate_tags in itertools.product(
                    range(num_tags), repeat=seq_len)
            ]

            brute_force_log_norm = tf.reduce_logsumexp(all_sequence_scores)
            log_norm = text.crf_log_norm(
                inputs=batched_inputs,
                sequence_lengths=batched_lengths,
                transition_params=transitions,
            )
            log_norm = tf.squeeze(log_norm, [0])
            tf_brute_force_log_norm, tf_log_norm = self.evaluate(
                [brute_force_log_norm, log_norm])

            self.assertAllClose(tf_log_norm, tf_brute_force_log_norm)