def crf_decode(potentials, transition_params, sequence_length):
    """Decode the highest scoring sequence of tags in TensorFlow.

    This is a function for tensor.

    Args:
      potentials: A [batch_size, max_seq_len, num_tags] tensor of
        unary potentials.
      transition_params: A [num_tags, num_tags] matrix of
        binary potentials.
      sequence_length: A [batch_size] vector of true sequence lengths.

    Returns:
      decode_tags: A [batch_size, max_seq_len] tensor, with dtype tf.int32.
        Contains the highest scoring tag indices.
      best_score: A [batch_size] tensor, containing the score of decode_tags.
    """
    # For simplicity, in shape comments, denote:
    # 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output).
    num_tags = potentials.get_shape()[2].value

    # The first time step seeds the recursion, so the RNNs below only run over
    # the remaining steps. Sequence length is not allowed to be less than
    # zero, so clamp: an empty sequence would otherwise ask for -1 steps.
    sequence_length_less_one = sequence_length - 1
    sequence_length_less_one = math_ops.maximum(
        array_ops.zeros_like(sequence_length_less_one),
        sequence_length_less_one)

    # Computes forward decoding. Get last score and backpointers.
    crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params)
    initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1])
    initial_state = array_ops.squeeze(initial_state, axis=[1])  # [B, O]
    inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1])  # [B, T-1, O]
    backpointers, last_score = rnn.dynamic_rnn(  # [B, T - 1, O], [B, O]
        crf_fwd_cell,
        inputs=inputs,
        sequence_length=sequence_length_less_one,
        initial_state=initial_state,
        time_major=False,
        dtype=dtypes.int32)
    # Reverse the valid prefix of each row so the backward pass can walk from
    # the last step towards the first.
    backpointers = gen_array_ops.reverse_sequence(  # [B, T - 1, O]
        backpointers, sequence_length_less_one, seq_dim=1)

    # Computes backward decoding. Extract tag indices from backpointers.
    crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags)
    initial_state = math_ops.cast(math_ops.argmax(last_score, axis=1),
                                  dtype=dtypes.int32)  # [B]
    initial_state = array_ops.expand_dims(initial_state, axis=-1)  # [B, 1]
    decode_tags, _ = rnn.dynamic_rnn(  # [B, T - 1, 1]
        crf_bwd_cell,
        inputs=backpointers,
        sequence_length=sequence_length_less_one,
        initial_state=initial_state,
        time_major=False,
        dtype=dtypes.int32)
    decode_tags = array_ops.squeeze(decode_tags, axis=[2])  # [B, T - 1]
    # Prepend the best final tag, then flip back into chronological order.
    decode_tags = array_ops.concat([initial_state, decode_tags],
                                   axis=1)  # [B, T]
    decode_tags = gen_array_ops.reverse_sequence(  # [B, T]
        decode_tags, sequence_length, seq_dim=1)

    best_score = math_ops.reduce_max(last_score, axis=1)  # [B]
    return decode_tags, best_score
def crf_decode(potentials, transition_params, sequence_length):
    """Decode the highest scoring sequence of tags in TensorFlow.

    This is a function for tensor.

    Args:
      potentials: A [batch_size, max_seq_len, num_tags] tensor of
        unary potentials.
      transition_params: A [num_tags, num_tags] matrix of
        binary potentials.
      sequence_length: A [batch_size] vector of true sequence lengths.

    Returns:
      decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`.
        Contains the highest scoring tag indices.
      best_score: A [batch_size] vector, containing the score of `decode_tags`.
    """
    # For simplicity, in shape comments, denote:
    # 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output).
    num_tags = potentials.get_shape()[2].value

    # Number of RNN steps per example: every step after the first one.
    # Sequence length is not allowed to be less than zero, so a zero-length
    # example is clamped to zero steps instead of -1.
    num_steps = sequence_length - 1
    num_steps = math_ops.maximum(array_ops.zeros_like(num_steps), num_steps)

    # Computes forward decoding. Get last score and backpointers.
    crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params)
    initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1])
    initial_state = array_ops.squeeze(initial_state, axis=[1])  # [B, O]
    inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1])  # [B, T-1, O]
    backpointers, last_score = rnn.dynamic_rnn(  # [B, T - 1, O], [B, O]
        crf_fwd_cell,
        inputs=inputs,
        sequence_length=num_steps,
        initial_state=initial_state,
        time_major=False,
        dtype=dtypes.int32)
    # Flip each row's valid steps so the backward RNN consumes them last-first.
    backpointers = gen_array_ops.reverse_sequence(  # [B, T - 1, O]
        backpointers, num_steps, seq_dim=1)

    # Computes backward decoding. Extract tag indices from backpointers.
    crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags)
    initial_state = math_ops.cast(math_ops.argmax(last_score, axis=1),
                                  dtype=dtypes.int32)  # [B]
    initial_state = array_ops.expand_dims(initial_state, axis=-1)  # [B, 1]
    decode_tags, _ = rnn.dynamic_rnn(  # [B, T - 1, 1]
        crf_bwd_cell,
        inputs=backpointers,
        sequence_length=num_steps,
        initial_state=initial_state,
        time_major=False,
        dtype=dtypes.int32)
    decode_tags = array_ops.squeeze(decode_tags, axis=[2])  # [B, T - 1]
    # The final-step tag comes first here; the reverse below restores
    # chronological order over each row's true length.
    decode_tags = array_ops.concat([initial_state, decode_tags],
                                   axis=1)  # [B, T]
    decode_tags = gen_array_ops.reverse_sequence(  # [B, T]
        decode_tags, sequence_length, seq_dim=1)

    best_score = math_ops.reduce_max(last_score, axis=1)  # [B]
    return decode_tags, best_score
def _multi_seq_fn():
    """Decode the best-scoring tag sequence with a forward and backward pass.

    Reads `potentials`, `transition_params` and `sequence_length` from the
    enclosing scope.

    Returns:
      A `(decode_tags, best_score)` pair: the decoded tag indices as an int32
      [B, T] tensor and the score of each decoded sequence as a [B] tensor.
    """
    # Shape shorthand: B = batch_size, T = max_seq_len, O = num_tags (output).
    tag_count = tensor_shape.dimension_value(potentials.shape[2])

    # Forward pass: yields the final scores and per-step backpointers.
    forward_cell = CrfDecodeForwardRnnCell(transition_params)
    first_step = array_ops.squeeze(
        array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1]),
        axis=[1])  # [B, O]
    remaining_steps = array_ops.slice(
        potentials, [0, 1, 0], [-1, -1, -1])  # [B, T-1, O]

    # Sequence length is not allowed to be less than zero.
    zero = constant_op.constant(0, dtype=sequence_length.dtype)
    steps = math_ops.maximum(zero, sequence_length - 1)

    backpointers, final_scores = rnn.dynamic_rnn(  # [B, T - 1, O], [B, O]
        forward_cell,
        inputs=remaining_steps,
        sequence_length=steps,
        initial_state=first_step,
        time_major=False,
        dtype=dtypes.int32)
    reversed_backpointers = gen_array_ops.reverse_sequence(  # [B, T - 1, O]
        backpointers, steps, seq_dim=1)

    # Backward pass: follow the backpointers starting from the best final tag.
    backward_cell = CrfDecodeBackwardRnnCell(tag_count)
    best_last_tag = math_ops.argmax(final_scores, axis=1)  # [B]
    best_last_tag = math_ops.cast(best_last_tag, dtype=dtypes.int32)
    best_last_tag = array_ops.expand_dims(best_last_tag, axis=-1)  # [B, 1]
    traced_tags, _ = rnn.dynamic_rnn(  # [B, T - 1, 1]
        backward_cell,
        inputs=reversed_backpointers,
        sequence_length=steps,
        initial_state=best_last_tag,
        time_major=False,
        dtype=dtypes.int32)
    traced_tags = array_ops.squeeze(traced_tags, axis=[2])  # [B, T - 1]

    # Tags are in last-to-first order here; flip back over each true length.
    tags_last_first = array_ops.concat(
        [best_last_tag, traced_tags], axis=1)  # [B, T]
    decode_tags = gen_array_ops.reverse_sequence(  # [B, T]
        tags_last_first, sequence_length, seq_dim=1)

    best_score = math_ops.reduce_max(final_scores, axis=1)  # [B]
    return decode_tags, best_score
def _multi_seq_fn():
    """Find the highest scoring tag sequence for each batch entry.

    Uses `potentials`, `transition_params` and `sequence_length` from the
    enclosing scope. A forward RNN pass records backpointers and final
    scores; a backward RNN pass turns the backpointers into tag indices.

    Returns:
      decode_tags: int32 tag indices, shape [B, T].
      best_score: score of the decoded sequence, shape [B].
    """
    # In shape comments: B = 'batch_size', T = 'max_seq_len', O = 'num_tags'.
    n_tags = potentials.get_shape()[2].value

    # ---- Forward decoding: last score and backpointers. ----
    fwd_cell = CrfDecodeForwardRnnCell(transition_params)
    head = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1])
    head = array_ops.squeeze(head, axis=[1])  # [B, O]
    tail = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1])  # [B, T-1, O]

    # Sequence length is not allowed to be less than zero.
    tail_lengths = math_ops.maximum(
        constant_op.constant(0, dtype=sequence_length.dtype),
        sequence_length - 1)

    fwd_kwargs = dict(
        inputs=tail,
        sequence_length=tail_lengths,
        initial_state=head,
        time_major=False,
        dtype=dtypes.int32)
    backpointers, last_score = rnn.dynamic_rnn(  # [B, T - 1, O], [B, O]
        fwd_cell, **fwd_kwargs)
    backpointers = gen_array_ops.reverse_sequence(  # [B, T - 1, O]
        backpointers, tail_lengths, seq_dim=1)

    # ---- Backward decoding: tag indices from backpointers. ----
    bwd_cell = CrfDecodeBackwardRnnCell(n_tags)
    last_tag = math_ops.cast(
        math_ops.argmax(last_score, axis=1), dtype=dtypes.int32)  # [B]
    last_tag = array_ops.expand_dims(last_tag, axis=-1)  # [B, 1]
    bwd_kwargs = dict(
        inputs=backpointers,
        sequence_length=tail_lengths,
        initial_state=last_tag,
        time_major=False,
        dtype=dtypes.int32)
    earlier_tags, _ = rnn.dynamic_rnn(bwd_cell, **bwd_kwargs)  # [B, T - 1, 1]
    earlier_tags = array_ops.squeeze(earlier_tags, axis=[2])  # [B, T - 1]

    # Last tag first, then the traced-back tags; reverse into time order.
    decode_tags = gen_array_ops.reverse_sequence(  # [B, T]
        array_ops.concat([last_tag, earlier_tags], axis=1),
        sequence_length,
        seq_dim=1)

    best_score = math_ops.reduce_max(last_score, axis=1)  # [B]
    return decode_tags, best_score
def crf_decode(self, potentials, seq_lens):
    """Decode the highest scoring tag sequence for each batch entry.

    Runs a forward RNN pass over `potentials` using `self.transition_params`
    to collect backpointers and final scores, then a backward RNN pass that
    follows the backpointers from the best final tag.

    Args:
      potentials: Tensor sliced along three axes; the first step along
        axis 1 seeds the forward pass and the rest are its inputs.
        Presumably [batch_size, max_seq_len, num_tags] -- TODO confirm.
      seq_lens: Per-example sequence lengths, used as the RNN
        `sequence_length` (minus one, clamped at zero).

    Returns:
      A `(decode_tags, best_score)` tuple: int32 tag indices restored to
      chronological order, and the maximum final score per example.
    """
    # Forward pass.
    forward_cell = CrfDecodeForwardRnnCell(self.transition_params)
    first_step = array_ops.squeeze(
        array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1]), axis=[1])
    later_steps = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1])

    # The first step is consumed as the initial state, so the RNNs run for
    # one step fewer; never let that count drop below zero.
    zero = constant_op.constant(0, dtype=seq_lens.dtype)
    steps = math_ops.maximum(zero, seq_lens - 1)

    backpointers, final_scores = rnn.dynamic_rnn(
        forward_cell,
        inputs=later_steps,
        initial_state=first_step,
        sequence_length=steps,
        time_major=False,
        dtype=dtypes.int32)
    backpointers = gen_array_ops.reverse_sequence(
        backpointers, steps, seq_dim=1)

    # Backward pass, seeded with the arg-max tag of the final scores.
    backward_cell = CrfDecodeBackwardRnnCell(self.num_tags)
    last_tag = math_ops.argmax(final_scores, axis=1)
    last_tag = math_ops.cast(last_tag, dtype=dtypes.int32)
    last_tag = array_ops.expand_dims(last_tag, axis=-1)
    traced, _ = rnn.dynamic_rnn(
        backward_cell,
        inputs=backpointers,
        sequence_length=steps,
        initial_state=last_tag,
        time_major=False,
        dtype=dtypes.int32)
    traced = array_ops.squeeze(traced, axis=[2])

    # Tags are ordered last-to-first at this point; flip them back.
    reversed_tags = array_ops.concat([last_tag, traced], axis=1)
    decode_tags = gen_array_ops.reverse_sequence(
        reversed_tags, seq_lens, seq_dim=1)

    best_score = math_ops.reduce_max(final_scores, axis=1)
    return decode_tags, best_score
def _multi_seq_fn():
    """Compute per-step forward (alpha) and backward (beta) values.

    Reads `inputs`, `transition_params` and `sequence_lengths` from the
    enclosing scope.

    Returns:
      alphas_seq: Forward values for every time step (the first input
        followed by the forward RNN outputs), concatenated along axis 1.
      betas_seq: Backward values, reversed over each true sequence length so
        they follow the same time index as `alphas_seq`.
      log_norm: `reduce_logsumexp` of the final alphas over axis 1.
    """
    # Split up the first and rest of the inputs in preparation for the forward
    # algorithm.
    batch_size = array_ops.shape(inputs)[0]
    num_tags = array_ops.shape(inputs)[2]
    first_input = array_ops.slice(inputs, [0, 0, 0], [-1, 1, -1])
    first_input = array_ops.squeeze(first_input, [1])
    rest_of_input = array_ops.slice(inputs, [0, 1, 0], [-1, -1, -1])

    # Both RNNs run over every step except the first. Clamp at zero so a
    # zero-length sequence asks for 0 steps instead of a negative length.
    lengths_less_one = sequence_lengths - 1
    lengths_less_one = math_ops.maximum(
        array_ops.zeros_like(lengths_less_one), lengths_less_one)

    # Compute the alpha values in the forward algorithm
    forward_cell = CrfForwardRnnCell(transition_params)
    alphas_seq, alphas = rnn.dynamic_rnn(
        cell=forward_cell,
        inputs=rest_of_input,
        sequence_length=lengths_less_one,
        initial_state=first_input,
        dtype=dtypes.float32)
    # Get all alphas in each time steps
    alphas_seq = tf.concat(
        [tf.expand_dims(first_input, axis=1), alphas_seq], axis=1)

    # Compute the betas values in the backward algorithm.
    # Values are in log space, so the beta initialization is 0.0.
    beta_init = tf.constant(0.0, shape=[1, 1])
    beta_init = tf.tile(beta_init, multiples=[batch_size, num_tags])
    # reverse the sequence of inputs in forward algorithm for backward algorithm
    reversed_rest = gen_array_ops.reverse_sequence(
        rest_of_input, lengths_less_one, seq_dim=1)
    # transpose transition parameters for backward algorithm
    backward_cell = CrfBackwardRnnCell(
        tf.transpose(transition_params, perm=[1, 0]))
    # Only the per-step outputs are needed; the final state is discarded.
    betas_seq, _ = rnn.dynamic_rnn(
        cell=backward_cell,
        inputs=reversed_rest,
        sequence_length=lengths_less_one,
        initial_state=beta_init,
        dtype=dtypes.float32)
    betas_seq = tf.concat(
        [tf.expand_dims(beta_init, axis=1), betas_seq], axis=1)
    # reverse betas that follows same index as alphas
    betas_seq = tf.reverse_sequence(betas_seq, sequence_lengths, seq_dim=1)

    # crf log norm
    log_norm = math_ops.reduce_logsumexp(alphas, [1])
    return alphas_seq, betas_seq, log_norm
def _multi_seq_fn():
    """Decoding of the K highest scoring sequences.

    Reads `potentials`, `transition_params`, `sequence_length` and `K` from
    the enclosing scope.

    Returns:
      decode_tags: int32 tag indices, laid out [B, K, T] and trimmed along
        the K axis to `K_modified` entries.
      best_score: The top-K final scores, [B, K], trimmed the same way.
    """
    # For simplicity, in shape comments, denote:
    # 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output).
    num_tags = potentials.get_shape()[2].value
    batch_size = array_ops.shape(potentials)[0]

    # The first step seeds the recursion, so the RNNs run over the remaining
    # steps only. Clamp at zero so a zero-length sequence asks for 0 steps
    # rather than a negative length.
    sequence_length_less_one = sequence_length - 1
    sequence_length_less_one = math_ops.maximum(
        array_ops.zeros_like(sequence_length_less_one),
        sequence_length_less_one)

    # Computes forward decoding. Get last score and backpointers.
    crf_fwd_cell = CrfNbestDecodeForwardRnnCell(transition_params, K)
    # Kept as [B, 1, O] (not squeezed) because it is transposed and padded
    # below.
    initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1])

    # Padding initial state to fit N-best format: each tag gets K slots, the
    # first holding its real score and the other K - 1 a very negative filler
    # (presumably so they are never preferred -- confirm against the cell).
    modified_initial_state = tf.transpose(initial_state, perm=[0, 2, 1])
    padding_for_init_state = tf.constant(-1.0e38, shape=[1, 1, 1])
    padding_for_init_state = tf.tile(
        padding_for_init_state, multiples=[batch_size, num_tags, K - 1])
    modified_initial_state = tf.concat(
        [modified_initial_state, padding_for_init_state], axis=2)
    modified_initial_state = tf.reshape(
        modified_initial_state,
        shape=[array_ops.shape(potentials)[0], -1])  # [B, O*K]

    inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1])  # [B, T-1, O]

    # follow dynamic_rnn logic as a dynamic programming to get TopKs in each
    # step
    backpointers, last_score = rnn.dynamic_rnn(  # [B, T - 1, O*K], [B, O*K]
        crf_fwd_cell,
        inputs=inputs,
        sequence_length=sequence_length_less_one,
        initial_state=modified_initial_state,
        time_major=False,
        dtype=dtypes.int32)
    backpointers = gen_array_ops.reverse_sequence(  # [B, T - 1, O*K]
        backpointers, sequence_length_less_one, seq_dim=1)

    # Computes backward decoding. Extract tag indices from backpointers.
    crf_bwd_cell = CrfNbestDecodeBackwardRnnCell(num_tags, K)
    top_K_values, top_K_indices = tf.nn.top_k(last_score, K)
    initial_state = math_ops.cast(top_K_indices,  # [B, K]
                                  dtype=dtypes.int32)
    decode_tags, _ = rnn.dynamic_rnn(  # [B, T - 1, K]
        crf_bwd_cell,
        inputs=backpointers,
        sequence_length=sequence_length_less_one,
        initial_state=initial_state,
        time_major=False,
        dtype=dtypes.int32)
    initial_state = array_ops.expand_dims(initial_state, axis=1)  # [B, 1, K]
    decode_tags = array_ops.concat([initial_state, decode_tags],  # [B, T, K]
                                   axis=1)
    decode_tags = gen_array_ops.reverse_sequence(  # [B, T, K]
        decode_tags, sequence_length, seq_dim=1)

    # if K > num_tag ^ seq_len (all possible hypotheses), use
    # num_tag ^ seq_len as the number of results to keep. The comparison is
    # done in log space. NOTE(review): only sequence_length[0] is consulted,
    # so the trim is based on the first example in the batch.
    log_cnt_total_cases = tf.cast(
        sequence_length[0], dtypes.float32) * tf.log(
            tf.cast(tf.constant(num_tags), dtypes.float32))
    K_modified = tf.cond(
        tf.logical_or(
            tf.less(tf.log(tf.cast(tf.constant(K), dtypes.float32)),
                    log_cnt_total_cases),
            tf.less(log_cnt_total_cases, tf.constant(0.0))),
        lambda: tf.constant(K),
        lambda: tf.pow(tf.constant(num_tags), sequence_length[0]))

    decode_tags = tf.transpose(decode_tags, perm=[0, 2, 1])  # [B, K, T]
    # Presumably each entry indexes the flattened O*K axis, so dividing by K
    # and flooring recovers the tag -- confirm against the backward cell.
    decode_tags = decode_tags / tf.constant(K)
    decode_tags = tf.floor(decode_tags)
    decode_tags = math_ops.cast(decode_tags, dtype=dtypes.int32)

    best_score = top_K_values  # [B, K]

    # Keep only the first K_modified hypotheses.
    decode_tags = tf.slice(decode_tags, [0, 0, 0], [-1, K_modified, -1])
    best_score = tf.slice(best_score, [0, 0], [-1, K_modified])

    return decode_tags, best_score