def test_crf_log_norm_zero_seq_length(dtype):
    """Check `crf_log_norm` on a batch whose sequence lengths are all zero.

    Potentials and transitions are all ones, and every entry of
    `sequence_lengths` is 0. The log normalizer is expected to be exactly 0
    for each batch element.
    """
    batch_size, max_seq_len, num_tags = 2, 10, 5
    unary_potentials = tf.constant(
        np.ones([batch_size, max_seq_len, num_tags], dtype=dtype))
    transitions = tf.constant(np.ones([num_tags, num_tags], dtype=dtype))
    lengths = tf.constant(np.zeros([batch_size], dtype=np.int32))

    result = text.crf_log_norm(unary_potentials, lengths, transitions)

    # Tolerance is chosen by the helper from the dtype under test.
    test_utils.assert_allclose_according_to_type(
        result, np.zeros([batch_size], dtype=dtype))
def testCrfLogNormZeroSeqLength(self):
    """Check `crf_log_norm` on a float32 batch whose sequence lengths are all zero.

    Potentials and transitions are all ones, and every entry of
    `sequence_lengths` is 0. The evaluated log normalizer should be 0 for each
    batch element.
    """
    batch_size, max_seq_len, num_tags = 2, 10, 5
    unary_potentials = tf.constant(
        np.ones([batch_size, max_seq_len, num_tags], dtype=np.float32))
    transitions = tf.constant(np.ones([num_tags, num_tags], dtype=np.float32))
    lengths = tf.constant(np.zeros([batch_size], dtype=np.int32))

    computed = self.evaluate(
        text.crf_log_norm(unary_potentials, lengths, transitions))

    self.assertAllClose(computed, np.zeros([batch_size], dtype=np.float32))
def testCrfLogNorm(self):
    """Compare `crf_log_norm` against a brute-force log-sum-exp over tag paths.

    For each case, every possible tagging of the first `sequence_lengths`
    words is scored with `crf_sequence_score`. The log-sum-exp of those scores
    must match the dynamic-programming result of `crf_log_norm`.

    Two cases are covered:
    - a length-3 sequence padded to 4 words;
    - the length-1 special case.
    """
    transition_params = np.array(
        [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
    # Test both the length-1 and regular cases.
    sequence_lengths_list = [
        np.array(3, dtype=np.int32),
        np.array(1, dtype=np.int64),
    ]
    inputs_list = [
        np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]],
                 dtype=np.float32),
        np.array([[3, -1, 3]], dtype=np.float32),
    ]
    # The transition matrix is shared by every case and every path, so build
    # the tensor once instead of once per scored path.
    transitions = tf.constant(transition_params)

    for sequence_lengths, inputs in zip(sequence_lengths_list, inputs_list):
        num_words, num_tags = inputs.shape
        # Explicit Python int so `repeat=` and list repetition below never
        # depend on how numpy 0-d arrays/scalars interact with sequences.
        seq_len = int(sequence_lengths)
        # Tags past the real sequence length are padding; fill them with 0.
        padding = [0] * (num_words - seq_len)
        batched_inputs = tf.expand_dims(inputs, 0)
        batched_lengths = tf.expand_dims(sequence_lengths, 0)

        # Brute force: score every possible tag path over the real words.
        # The enumerated path has its own name so it no longer shadows any
        # per-case fixture variable.
        all_sequence_scores = [
            text.crf_sequence_score(
                inputs=batched_inputs,
                tag_indices=tf.expand_dims(list(path) + padding, 0),
                sequence_lengths=batched_lengths,
                transition_params=transitions,
            )
            for path in itertools.product(range(num_tags), repeat=seq_len)
        ]
        brute_force_log_norm = tf.reduce_logsumexp(all_sequence_scores)

        log_norm = text.crf_log_norm(
            inputs=batched_inputs,
            sequence_lengths=batched_lengths,
            transition_params=transitions,
        )
        log_norm = tf.squeeze(log_norm, [0])

        tf_brute_force_log_norm, tf_log_norm = self.evaluate(
            [brute_force_log_norm, log_norm])
        self.assertAllClose(tf_log_norm, tf_brute_force_log_norm)