예제 #1
0
def crf_decode(potentials, transition_params, sequence_length):
    """Decode the highest scoring sequence of tags in TensorFlow.

    This is a function for tensor.

    Args:
      potentials: A [batch_size, max_seq_len, num_tags] tensor of
                unary potentials.
      transition_params: A [num_tags, num_tags] matrix of
                binary potentials.
      sequence_length: A [batch_size] vector of true sequence lengths.
                The number of transitions (length - 1) is clamped at zero,
                so a zero-length entry never yields a negative length.

    Returns:
      decode_tags: A [batch_size, max_seq_len] tensor, with dtype tf.int32.
                  Contains the highest scoring tag indices.
      best_score: A [batch_size] tensor, containing the score of decode_tags.
    """
    # For simplicity, in shape comments, denote:
    # 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output).
    num_tags = potentials.get_shape()[2].value

    # Computes forward decoding. Get last score and backpointers.
    crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params)
    # The first time step's unary scores seed the recursion; the remaining
    # T-1 steps are fed through the RNN.
    initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1])
    initial_state = array_ops.squeeze(initial_state, axis=[1])  # [B, O]
    inputs = array_ops.slice(potentials, [0, 1, 0],
                             [-1, -1, -1])  # [B, T-1, O]
    # Both RNN passes cover only the T-1 transitions. Clamp at zero so a
    # zero-length sequence does not hand -1 to dynamic_rnn/reverse_sequence.
    sequence_length_less_one = math_ops.maximum(sequence_length - 1, 0)
    backpointers, last_score = rnn.dynamic_rnn(
        crf_fwd_cell,
        inputs=inputs,
        sequence_length=sequence_length_less_one,
        initial_state=initial_state,
        time_major=False,
        dtype=dtypes.int32)  # [B, T - 1, O], [B, O]
    # Reverse each sequence's valid steps so the backward pass can walk the
    # backpointers from the last step to the first with a forward RNN.
    backpointers = gen_array_ops.reverse_sequence(backpointers,
                                                  sequence_length_less_one,
                                                  seq_dim=1)  # [B, T-1, O]

    # Computes backward decoding. Extract tag indices from backpointers.
    crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags)
    # Start backtracking from the best-scoring final tag.
    initial_state = math_ops.cast(math_ops.argmax(last_score, axis=1),
                                  dtype=dtypes.int32)  # [B]
    initial_state = array_ops.expand_dims(initial_state, axis=-1)  # [B, 1]
    decode_tags, _ = rnn.dynamic_rnn(crf_bwd_cell,
                                     inputs=backpointers,
                                     sequence_length=sequence_length_less_one,
                                     initial_state=initial_state,
                                     time_major=False,
                                     dtype=dtypes.int32)  # [B, T - 1, 1]
    decode_tags = array_ops.squeeze(decode_tags, axis=[2])  # [B, T - 1]
    # The path is assembled last-to-first; reverse each valid prefix back.
    decode_tags = array_ops.concat([initial_state, decode_tags],
                                   axis=1)  # [B, T]
    decode_tags = gen_array_ops.reverse_sequence(decode_tags,
                                                 sequence_length,
                                                 seq_dim=1)  # [B, T]

    best_score = math_ops.reduce_max(last_score, axis=1)  # [B]
    return decode_tags, best_score
예제 #2
0
파일: crf.py 프로젝트: SylChan/tensorflow
def crf_decode(potentials, transition_params, sequence_length):
  """Decode the highest scoring sequence of tags in TensorFlow.

  This is a function for tensor.

  Args:
    potentials: A [batch_size, max_seq_len, num_tags] tensor of
              unary potentials.
    transition_params: A [num_tags, num_tags] matrix of
              binary potentials.
    sequence_length: A [batch_size] vector of true sequence lengths.
              The number of transitions (length - 1) is clamped at zero,
              so a zero-length entry never yields a negative length.

  Returns:
    decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`.
                Contains the highest scoring tag indices.
    best_score: A [batch_size] vector, containing the score of `decode_tags`.
  """
  # For simplicity, in shape comments, denote:
  # 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output).
  num_tags = potentials.get_shape()[2].value

  # Computes forward decoding. Get last score and backpointers.
  crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params)
  # The first step seeds the recursion; the other T-1 steps are scanned.
  initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1])
  initial_state = array_ops.squeeze(initial_state, axis=[1])      # [B, O]
  inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1])   # [B, T-1, O]
  # Both RNN passes cover only the T-1 transitions. Clamp at zero so a
  # zero-length sequence does not hand -1 to dynamic_rnn/reverse_sequence.
  sequence_length_less_one = math_ops.maximum(sequence_length - 1, 0)
  backpointers, last_score = rnn.dynamic_rnn(
      crf_fwd_cell,
      inputs=inputs,
      sequence_length=sequence_length_less_one,
      initial_state=initial_state,
      time_major=False,
      dtype=dtypes.int32)             # [B, T - 1, O], [B, O]
  # Reverse the valid steps so backtracking runs from last step to first.
  backpointers = gen_array_ops.reverse_sequence(
      backpointers, sequence_length_less_one, seq_dim=1)          # [B, T-1, O]

  # Computes backward decoding. Extract tag indices from backpointers.
  crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags)
  # Start backtracking from the best-scoring final tag.
  initial_state = math_ops.cast(math_ops.argmax(last_score, axis=1),
                                dtype=dtypes.int32)               # [B]
  initial_state = array_ops.expand_dims(initial_state, axis=-1)   # [B, 1]
  decode_tags, _ = rnn.dynamic_rnn(
      crf_bwd_cell,
      inputs=backpointers,
      sequence_length=sequence_length_less_one,
      initial_state=initial_state,
      time_major=False,
      dtype=dtypes.int32)           # [B, T - 1, 1]
  decode_tags = array_ops.squeeze(decode_tags, axis=[2])           # [B, T - 1]
  # The path is assembled last-to-first; reverse each valid prefix back.
  decode_tags = array_ops.concat([initial_state, decode_tags], axis=1)  # [B, T]
  decode_tags = gen_array_ops.reverse_sequence(
      decode_tags, sequence_length, seq_dim=1)                     # [B, T]

  best_score = math_ops.reduce_max(last_score, axis=1)             # [B]
  return decode_tags, best_score
예제 #3
0
    def _multi_seq_fn():
        """Decoding of highest scoring sequence.

        Uses `potentials`, `transition_params` and `sequence_length` from
        the enclosing scope.

        Returns:
          A `(decode_tags, best_score)` pair: `[B, T]` tag indices and the
          `[B]` score of each decoded path.
        """
        # Shape legend for the comments below:
        # B = batch_size, T = max_seq_len, O = num_tags (output).
        n_tags = tensor_shape.dimension_value(potentials.shape[2])

        # ---- Forward pass: running best scores plus backpointers. ----
        fwd_cell = CrfDecodeForwardRnnCell(transition_params)
        start_scores = array_ops.squeeze(
            array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1]),
            axis=[1])  # [B, O]
        step_inputs = array_ops.slice(
            potentials, [0, 1, 0], [-1, -1, -1])  # [B, T-1, O]
        # Clamp at zero: a length-0 sequence must not become length -1.
        n_transitions = math_ops.maximum(
            constant_op.constant(0, dtype=sequence_length.dtype),
            sequence_length - 1)
        bptrs, final_scores = rnn.dynamic_rnn(  # [B, T-1, O], [B, O]
            fwd_cell,
            inputs=step_inputs,
            sequence_length=n_transitions,
            initial_state=start_scores,
            time_major=False,
            dtype=dtypes.int32)
        bptrs = gen_array_ops.reverse_sequence(
            bptrs, n_transitions, seq_dim=1)  # [B, T-1, O]

        # ---- Backward pass: follow backpointers from the best last tag. ----
        bwd_cell = CrfDecodeBackwardRnnCell(n_tags)
        last_tag = array_ops.expand_dims(
            math_ops.cast(math_ops.argmax(final_scores, axis=1),
                          dtype=dtypes.int32),
            axis=-1)  # [B, 1]
        traced, _ = rnn.dynamic_rnn(  # [B, T-1, 1]
            bwd_cell,
            inputs=bptrs,
            sequence_length=n_transitions,
            initial_state=last_tag,
            time_major=False,
            dtype=dtypes.int32)
        tags = array_ops.concat(
            [last_tag, array_ops.squeeze(traced, axis=[2])],
            axis=1)  # [B, T]
        # The path was assembled last-to-first; flip each valid prefix back.
        tags = gen_array_ops.reverse_sequence(tags, sequence_length, seq_dim=1)

        return tags, math_ops.reduce_max(final_scores, axis=1)  # [B]
예제 #4
0
  def _multi_seq_fn():
    """Decoding of highest scoring sequence.

    Closes over `potentials`, `transition_params` and `sequence_length`
    from the enclosing function.

    Returns:
      A `(decode_tags, best_score)` tuple: `[B, T]` tag indices and the
      `[B]` score of each decoded path.
    """
    # Shapes below use B = batch_size, T = max_seq_len, O = num_tags.
    num_tags = potentials.shape[2].value

    # Forward decoding: the first step's unary scores seed the state and the
    # remaining T-1 steps are scanned to collect backpointers.
    forward_cell = CrfDecodeForwardRnnCell(transition_params)
    seed = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1])
    seed = array_ops.squeeze(seed, axis=[1])  # [B, O]
    steps = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1])  # [B, T-1, O]
    zero = constant_op.constant(0, dtype=sequence_length.dtype)
    # Sequence length is not allowed to be less than zero.
    num_steps = math_ops.maximum(zero, sequence_length - 1)
    pointers, final_score = rnn.dynamic_rnn(
        forward_cell,
        inputs=steps,
        sequence_length=num_steps,
        initial_state=seed,
        time_major=False,
        dtype=dtypes.int32)  # [B, T - 1, O], [B, O]
    pointers = gen_array_ops.reverse_sequence(
        pointers, num_steps, seq_dim=1)  # [B, T - 1, O]

    # Backward decoding: start from the argmax final tag and walk back.
    backward_cell = CrfDecodeBackwardRnnCell(num_tags)
    best_final = math_ops.argmax(final_score, axis=1)
    best_final = math_ops.cast(best_final, dtype=dtypes.int32)  # [B]
    best_final = array_ops.expand_dims(best_final, axis=-1)  # [B, 1]
    walked, _ = rnn.dynamic_rnn(
        backward_cell,
        inputs=pointers,
        sequence_length=num_steps,
        initial_state=best_final,
        time_major=False,
        dtype=dtypes.int32)  # [B, T - 1, 1]
    walked = array_ops.squeeze(walked, axis=[2])  # [B, T - 1]
    path = array_ops.concat([best_final, walked], axis=1)  # [B, T]
    # Undo the last-to-first ordering within each sequence's valid prefix.
    path = gen_array_ops.reverse_sequence(path, sequence_length, seq_dim=1)

    path_score = math_ops.reduce_max(final_score, axis=1)  # [B]
    return path, path_score
예제 #5
0
 def crf_decode(self, potentials, seq_lens):
     """Decode the highest-scoring tag sequence for each batch entry.

     Args:
       potentials: unary scores, presumably [batch, max_len, num_tags]
         (TODO confirm against callers).
       seq_lens: true length of every sequence in the batch.

     Returns:
       A `(decode_tags, best_score)` pair: the decoded tag indices and the
       best path score of each sequence.
     """
     # Forward pass: seed with the first step, scan the remaining steps.
     forward_cell = CrfDecodeForwardRnnCell(self.transition_params)
     first_scores = array_ops.squeeze(
         array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1]), axis=[1])
     remaining = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1])
     # Number of transitions per sequence; never negative.
     n_steps = math_ops.maximum(
         constant_op.constant(0, dtype=seq_lens.dtype), seq_lens - 1)
     pointers, final_scores = rnn.dynamic_rnn(
         forward_cell,
         inputs=remaining,
         initial_state=first_scores,
         sequence_length=n_steps,
         time_major=False,
         dtype=dtypes.int32)
     pointers = gen_array_ops.reverse_sequence(pointers, n_steps, seq_dim=1)

     # Backward pass: backtrack from the best final tag.
     backward_cell = CrfDecodeBackwardRnnCell(self.num_tags)
     best_last = array_ops.expand_dims(
         math_ops.cast(math_ops.argmax(final_scores, axis=1),
                       dtype=dtypes.int32),
         axis=-1)
     traced, _ = rnn.dynamic_rnn(
         backward_cell,
         inputs=pointers,
         sequence_length=n_steps,
         initial_state=best_last,
         time_major=False,
         dtype=dtypes.int32)
     path = array_ops.concat(
         [best_last, array_ops.squeeze(traced, axis=[2])], axis=1)
     # The path was built last-to-first; restore chronological order.
     path = gen_array_ops.reverse_sequence(path, seq_lens, seq_dim=1)
     return path, math_ops.reduce_max(final_scores, axis=1)
예제 #6
0
    def _multi_seq_fn():
        """Compute forward (alpha) and backward (beta) CRF scores.

        Reads `inputs`, `transition_params` and `sequence_lengths` from the
        enclosing scope.

        Returns:
          alphas_seq: alpha values for every time step, with the raw first
            input prepended (presumably [B, T, O] — TODO confirm).
          betas_seq: beta values, re-aligned to the same time indices as
            `alphas_seq`.
          log_norm: logsumexp over tags of the final alpha state, per
            sequence (the CRF log normalizer).
        """
        # Split up the first and rest of the inputs in preparation for the
        # forward algorithm.
        batch_size = array_ops.shape(inputs)[0]
        num_tags = array_ops.shape(inputs)[2]

        first_input = array_ops.slice(inputs, [0, 0, 0], [-1, 1, -1])
        first_input = array_ops.squeeze(first_input, [1])
        rest_of_input = array_ops.slice(inputs, [0, 1, 0], [-1, -1, -1])

        # Both recursions cover only the T-1 transitions. Clamp at zero so a
        # zero-length sequence does not hand -1 to dynamic_rnn or
        # reverse_sequence.
        sequence_lengths_less_one = math_ops.maximum(sequence_lengths - 1, 0)

        # Compute the alpha values in the forward algorithm.
        forward_cell = CrfForwardRnnCell(transition_params)
        alphas_seq, alphas = rnn.dynamic_rnn(
            cell=forward_cell,
            inputs=rest_of_input,
            sequence_length=sequence_lengths_less_one,
            initial_state=first_input,
            dtype=dtypes.float32)
        # Prepend the first step so alphas_seq covers every time step.
        alphas_seq = tf.concat(
            [tf.expand_dims(first_input, axis=1), alphas_seq], axis=1)

        # Compute the beta values in the backward algorithm. Scores are in
        # log space, so the initial beta is log(1) = 0.
        first_input = tf.constant(0.0, shape=[1, 1])
        first_input = tf.tile(first_input, multiples=[batch_size, num_tags])

        # Reverse each sequence's valid steps so the backward recursion can
        # run on a forward-scanning RNN.
        rest_of_input = gen_array_ops.reverse_sequence(
            rest_of_input, sequence_lengths_less_one, seq_dim=1)

        # Transpose transition parameters for the backward direction.
        backward_cell = CrfBackwardRnnCell(
            tf.transpose(transition_params, perm=[1, 0]))
        # The final backward state is not needed, only the per-step outputs.
        betas_seq, _ = rnn.dynamic_rnn(
            cell=backward_cell,
            inputs=rest_of_input,
            sequence_length=sequence_lengths_less_one,
            initial_state=first_input,
            dtype=dtypes.float32)

        betas_seq = tf.concat(
            [tf.expand_dims(first_input, axis=1), betas_seq], axis=1)

        # Reverse betas back so they follow the same index as alphas.
        betas_seq = gen_array_ops.reverse_sequence(
            betas_seq, sequence_lengths, seq_dim=1)

        # CRF log norm.
        log_norm = math_ops.reduce_logsumexp(alphas, [1])

        return alphas_seq, betas_seq, log_norm
예제 #7
0
    def _multi_seq_fn():
        """Decoding of the K highest scoring sequences (N-best).

        Reads `potentials`, `transition_params`, `sequence_length` and `K`
        from the enclosing scope.

        Returns:
          decode_tags: int32 tag indices of the top hypotheses, [B, K', T],
            where K' = min(K, num_tags ** sequence_length[0]).
          best_score: scores of those hypotheses, [B, K'].
        """
        # For simplicity, in shape comments, denote:
        # 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output).
        num_tags = potentials.get_shape()[2].value
        batch_size = array_ops.shape(potentials)[0]

        # Both RNN passes cover only the T-1 transitions. Clamp at zero so a
        # zero-length sequence does not hand -1 to dynamic_rnn or
        # reverse_sequence.
        sequence_length_less_one = math_ops.maximum(sequence_length - 1, 0)

        # Computes forward decoding. Get last score and backpointers.
        crf_fwd_cell = CrfNbestDecodeForwardRnnCell(transition_params, K)
        initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1])

        # Pad the initial state to the N-best layout [B, O, K]. Each tag keeps
        # its real score in its first slot, and the other K - 1 slots get a
        # near -inf score so they presumably never win a top-K pick.
        modified_initial_state = tf.transpose(initial_state, perm=[0, 2, 1])
        padding_for_init_state = tf.constant(-1.0e38, shape=[1, 1, 1])
        padding_for_init_state = tf.tile(
            padding_for_init_state, multiples=[batch_size, num_tags, K - 1])
        modified_initial_state = tf.concat(
            [modified_initial_state, padding_for_init_state], axis=2)
        # Flatten tag-major, so flat index = tag * K + slot.
        modified_initial_state = tf.reshape(
            modified_initial_state, shape=[batch_size, -1])  # [B, O*K]

        inputs = array_ops.slice(potentials, [0, 1, 0],
                                 [-1, -1, -1])  # [B, T-1, O]

        # Follow dynamic_rnn logic as a dynamic program to keep the top K
        # partial paths at each step.
        backpointers, last_score = rnn.dynamic_rnn(  # [B, T - 1, O*K], [B, O*K]
            crf_fwd_cell,
            inputs=inputs,
            sequence_length=sequence_length_less_one,
            initial_state=modified_initial_state,
            time_major=False,
            dtype=dtypes.int32)

        backpointers = gen_array_ops.reverse_sequence(  # [B, T - 1, O*K]
            backpointers, sequence_length_less_one, seq_dim=1)

        # Computes backward decoding. Extract tag indices from backpointers.
        crf_bwd_cell = CrfNbestDecodeBackwardRnnCell(num_tags, K)

        top_K_values, top_K_indices = tf.nn.top_k(last_score, K)

        initial_state = math_ops.cast(
            top_K_indices,  # [B, K]
            dtype=dtypes.int32)

        decode_tags, _ = rnn.dynamic_rnn(  # [B, T - 1, K]
            crf_bwd_cell,
            inputs=backpointers,
            sequence_length=sequence_length_less_one,
            initial_state=initial_state,
            time_major=False,
            dtype=dtypes.int32)

        initial_state = array_ops.expand_dims(initial_state,
                                              axis=[1])  # [B, 1, K]
        decode_tags = array_ops.concat(
            [initial_state, decode_tags],  # [B, T, K]
            axis=1)
        # The paths were assembled last-to-first; restore chronological order.
        decode_tags = gen_array_ops.reverse_sequence(  # [B, T, K]
            decode_tags, sequence_length, seq_dim=1)

        # If K exceeds num_tags ** seq_len (all possible hypotheses), trim to
        # num_tags ** seq_len. The comparison is done in log space to avoid
        # overflow.
        # NOTE(review): only sequence_length[0] is used, which assumes every
        # sequence in the batch has the same length — confirm with callers.
        log_cnt_total_cases = tf.cast(
            sequence_length[0], dtypes.float32) * tf.log(
                tf.cast(tf.constant(num_tags), dtypes.float32))
        K_modified = tf.cond(
            tf.logical_or(
                tf.less(tf.log(tf.cast(tf.constant(K), dtypes.float32)),
                        log_cnt_total_cases),
                tf.less(log_cnt_total_cases, tf.constant(0.0))),
            lambda: tf.constant(K),
            # Cast the exponent so this branch is int32 like the other one,
            # even when sequence_length is int64. tf.slice below also needs
            # an int32 size.
            lambda: tf.pow(tf.constant(num_tags),
                           math_ops.cast(sequence_length[0], dtypes.int32)))

        # Indices live in the flattened tag-major [O*K] space, so integer
        # division by K recovers the tag index.
        decode_tags = tf.transpose(decode_tags, perm=[0, 2, 1])  # [B, K, T]
        decode_tags = decode_tags / tf.constant(K)
        decode_tags = tf.floor(decode_tags)
        decode_tags = math_ops.cast(decode_tags, dtype=dtypes.int32)

        best_score = top_K_values  # [B, K]

        decode_tags = tf.slice(decode_tags, [0, 0, 0], [-1, K_modified, -1])
        best_score = tf.slice(best_score, [0, 0], [-1, K_modified])

        return decode_tags, best_score