Example #1
0
  def testCrfLogNorm(self):
    """Checks `crf_log_norm` against a brute-force log-sum-exp over all paths.

    Every tag path of length `sequence_lengths` is enumerated and scored with
    `crf_sequence_score`. The scores are reduced with `reduce_logsumexp`, and
    the result must match the value computed by `crf_log_norm`.
    """
    inputs = np.array(
        [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32)
    transition_params = np.array(
        [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
    num_words = inputs.shape[0]
    num_tags = inputs.shape[1]
    sequence_lengths = np.array(3, dtype=np.int32)
    seq_len = int(sequence_lengths)
    # Positions past the sequence length are filled with tag 0 so that every
    # tag_indices row has exactly num_words entries, like `inputs`.
    padding = [0] * (num_words - seq_len)
    with self.test_session() as sess:
      # These batch-of-one tensors are identical for every enumerated path,
      # so build them once instead of once per path (and again below).
      batched_inputs = array_ops.expand_dims(inputs, 0)
      batched_lengths = array_ops.expand_dims(sequence_lengths, 0)
      transitions = constant_op.constant(transition_params)

      # Compare the dynamic program with brute force computation.
      all_sequence_scores = []
      for path in itertools.product(range(num_tags), repeat=seq_len):
        tag_indices = list(path) + padding
        all_sequence_scores.append(
            crf.crf_sequence_score(
                inputs=batched_inputs,
                tag_indices=array_ops.expand_dims(tag_indices, 0),
                sequence_lengths=batched_lengths,
                transition_params=transitions))

      brute_force_log_norm = math_ops.reduce_logsumexp(all_sequence_scores)
      log_norm = crf.crf_log_norm(
          inputs=batched_inputs,
          sequence_lengths=batched_lengths,
          transition_params=transitions)
      # Remove the size-1 batch dimension added by expand_dims above.
      log_norm = array_ops.squeeze(log_norm, [0])
      tf_brute_force_log_norm, tf_log_norm = sess.run(
          [brute_force_log_norm, log_norm])

      self.assertAllClose(tf_log_norm, tf_brute_force_log_norm)
Example #2
0
    def testCrfLogNorm(self):
        """Checks `crf_log_norm` against brute-force enumeration of all paths.

        The test covers both a regular multi-step sequence and the length-1
        case. For each case, every tag path of the given length is scored with
        `crf_sequence_score` and the scores are reduced with
        `reduce_logsumexp`. The result must match the value computed by
        `crf_log_norm`.
        """
        transition_params = np.array([[-3, 5, -2], [3, 4, 1], [1, 2, 1]],
                                     dtype=np.float32)
        # Test both the length-1 and regular cases.
        sequence_lengths_list = [
            np.array(3, dtype=np.int32),
            np.array(1, dtype=np.int32)
        ]
        inputs_list = [
            np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]],
                     dtype=np.float32),
            np.array([[3, -1, 3]], dtype=np.float32),
        ]
        # No fixed tag_indices are needed: the brute-force loop below
        # enumerates every possible tag path itself.

        for sequence_lengths, inputs in zip(sequence_lengths_list, inputs_list):
            num_words, num_tags = inputs.shape
            seq_len = int(sequence_lengths)
            # Fill positions past the sequence length with tag 0 so every
            # tag_indices row has exactly num_words entries, like `inputs`.
            padding = [0] * (num_words - seq_len)
            with self.test_session() as sess:
                # Identical for every enumerated path; build them once.
                batched_inputs = array_ops.expand_dims(inputs, 0)
                batched_lengths = array_ops.expand_dims(sequence_lengths, 0)
                transitions = constant_op.constant(transition_params)

                # Compare the dynamic program with brute force computation.
                all_sequence_scores = []
                for path in itertools.product(range(num_tags), repeat=seq_len):
                    tag_indices = list(path) + padding
                    all_sequence_scores.append(
                        crf.crf_sequence_score(
                            inputs=batched_inputs,
                            tag_indices=array_ops.expand_dims(tag_indices, 0),
                            sequence_lengths=batched_lengths,
                            transition_params=transitions))

                brute_force_log_norm = math_ops.reduce_logsumexp(
                    all_sequence_scores)
                log_norm = crf.crf_log_norm(
                    inputs=batched_inputs,
                    sequence_lengths=batched_lengths,
                    transition_params=transitions)
                # Remove the size-1 batch dimension added by expand_dims.
                log_norm = array_ops.squeeze(log_norm, [0])
                tf_brute_force_log_norm, tf_log_norm = sess.run(
                    [brute_force_log_norm, log_norm])

                self.assertAllClose(tf_log_norm, tf_brute_force_log_norm)
 def testCrfLogNormZeroSeqLength(self):
   """Test `crf_log_norm` on a batch in which every sequence length is zero.

   Each element of the batch is expected to have a log normalizer of 0.
   """
   batch_size, max_len, num_tags = 2, 10, 5
   with self.test_session() as sess:
     unaries = constant_op.constant(
         np.ones([batch_size, max_len, num_tags], dtype=np.float32))
     transitions = constant_op.constant(
         np.ones([num_tags, num_tags], dtype=np.float32))
     zero_lengths = constant_op.constant(
         np.zeros([batch_size], dtype=np.int32))
     expected = np.zeros([batch_size], dtype=np.float32)
     actual = sess.run(crf.crf_log_norm(unaries, zero_lengths, transitions))
     self.assertAllClose(actual, expected)
Example #4
0
 def testCrfLogNormZeroSeqLength(self):
     """Verifies `crf_log_norm` when `sequence_lengths` contains zeros.

     Both sequences in the batch have length 0. The test expects a log
     normalizer of 0 for each of them, within `assertAllClose` tolerance.
     """
     score_shape = [2, 10, 5]
     transition_shape = [5, 5]
     with self.test_session() as sess:
         scores = constant_op.constant(np.ones(score_shape, dtype=np.float32))
         transitions = constant_op.constant(
             np.ones(transition_shape, dtype=np.float32))
         lengths = constant_op.constant(np.zeros([2], dtype=np.int32))
         want = np.zeros([2], dtype=np.float32)
         got = sess.run(crf.crf_log_norm(scores, lengths, transitions))
         self.assertAllClose(got, want)