def testCrfLogNorm(self):
  """Checks crf_log_norm against a brute-force log-sum-exp over tag paths.

  Every tag path of length `sequence_lengths` is scored with
  crf_sequence_score, after padding it with tag 0 up to the full number of
  input rows. The log-sum-exp of those scores must match the value returned
  by crf_log_norm for the same batch of one.
  """
  transition_params = np.array(
      [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
  inputs = np.array(
      [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
  num_words, num_tags = inputs.shape[0], inputs.shape[1]
  sequence_lengths = np.array(3, dtype=np.int32)

  with self.test_session() as sess:

    def score_path(path):
      # Pad the enumerated prefix with tag 0 so it spans every input row.
      padded_path = list(path) + [0] * (num_words - sequence_lengths)
      return crf.crf_sequence_score(
          inputs=array_ops.expand_dims(inputs, 0),
          tag_indices=array_ops.expand_dims(padded_path, 0),
          sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
          transition_params=constant_op.constant(transition_params))

    # Brute force: score every possible tag assignment of the real prefix.
    path_scores = [
        score_path(path)
        for path in itertools.product(range(num_tags),
                                      repeat=sequence_lengths)
    ]
    brute_force_log_norm = math_ops.reduce_logsumexp(path_scores)

    # Dynamic program under test; drop the batch dimension of size 1.
    log_norm = array_ops.squeeze(
        crf.crf_log_norm(
            inputs=array_ops.expand_dims(inputs, 0),
            sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
            transition_params=constant_op.constant(transition_params)),
        [0])

    tf_brute_force_log_norm, tf_log_norm = sess.run(
        [brute_force_log_norm, log_norm])
    self.assertAllClose(tf_log_norm, tf_brute_force_log_norm)
def testCrfLogNorm(self):
  """Checks crf_log_norm against brute-force enumeration of tag paths.

  Two cases are covered: a sequence of length 3 with one padded input row,
  and a single-row input with length 1. For each case, every tag path of
  length `sequence_lengths` is scored with crf_sequence_score, padded with
  tag 0 up to the number of input rows. The log-sum-exp of those scores is
  compared with the value returned by crf_log_norm.
  """
  transition_params = np.array(
      [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
  # Test both the length-1 and regular cases.
  sequence_lengths_list = [
      np.array(3, dtype=np.int32),
      np.array(1, dtype=np.int32),
  ]
  inputs_list = [
      np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]],
               dtype=np.float32),
      np.array([[3, -1, 3]], dtype=np.float32),
  ]
  for sequence_lengths, inputs in zip(sequence_lengths_list, inputs_list):
    num_words, num_tags = inputs.shape
    # Plain Python int for itertools.product and list arithmetic.
    seq_len = int(sequence_lengths)
    # Tag 0 fills the rows past the real sequence length.
    padding = [0] * (num_words - seq_len)
    with self.test_session() as sess:
      # These tensors are identical for every enumerated path, so build them
      # once per case instead of once per path.
      batched_inputs = array_ops.expand_dims(inputs, 0)
      batched_lengths = array_ops.expand_dims(sequence_lengths, 0)
      transitions = constant_op.constant(transition_params)

      # Compare the dynamic program with brute force computation.
      all_sequence_scores = [
          crf.crf_sequence_score(
              inputs=batched_inputs,
              tag_indices=array_ops.expand_dims(list(path) + padding, 0),
              sequence_lengths=batched_lengths,
              transition_params=transitions)
          for path in itertools.product(range(num_tags), repeat=seq_len)
      ]
      brute_force_log_norm = math_ops.reduce_logsumexp(all_sequence_scores)

      log_norm = crf.crf_log_norm(
          inputs=batched_inputs,
          sequence_lengths=batched_lengths,
          transition_params=transitions)
      # Drop the batch dimension of size 1.
      log_norm = array_ops.squeeze(log_norm, [0])

      tf_brute_force_log_norm, tf_log_norm = sess.run(
          [brute_force_log_norm, log_norm])
      self.assertAllClose(tf_log_norm, tf_brute_force_log_norm)
def testCrfLogNormZeroSeqLength(self):
  """Checks crf_log_norm when every entry of `sequence_lengths` is zero.

  The batch has two sequences of 10 steps over 5 tags. All unary and
  transition scores are 1. Both lengths are 0, and the test expects
  crf_log_norm to return a zero vector of shape [2].
  """
  batch_size, max_len, num_tags = 2, 10, 5
  expected_log_norm = np.zeros([batch_size], dtype=np.float32)
  with self.test_session() as sess:
    unary_scores = constant_op.constant(
        np.ones([batch_size, max_len, num_tags], dtype=np.float32))
    transitions = constant_op.constant(
        np.ones([num_tags, num_tags], dtype=np.float32))
    zero_lengths = constant_op.constant(
        np.zeros([batch_size], dtype=np.int32))
    log_norm = crf.crf_log_norm(unary_scores, zero_lengths, transitions)
    self.assertAllClose(sess.run(log_norm), expected_log_norm)
def testCrfLogNormZeroSeqLength(self):
  """Checks crf_log_norm on a batch whose sequence lengths are all zero.

  The two sequences (10 steps, 5 tags) use all-ones unary and transition
  scores. Both lengths are 0, and the expected result is a float32 zero
  vector with one entry per sequence.
  """
  num_sequences = 2
  num_steps = 10
  num_tags = 5
  with self.test_session() as sess:
    log_norm = crf.crf_log_norm(
        constant_op.constant(
            np.ones([num_sequences, num_steps, num_tags], dtype=np.float32)),
        constant_op.constant(np.zeros([num_sequences], dtype=np.int32)),
        constant_op.constant(
            np.ones([num_tags, num_tags], dtype=np.float32)))
    actual = sess.run(log_norm)
  self.assertAllClose(actual, np.zeros([num_sequences], dtype=np.float32))