def _create_topk_unique(inputs, k):
  """Creates the top k values in sorted order with indices."""
  height = inputs.shape[0]
  width = inputs.shape[1]
  neg_inf_r0 = tf.constant(-np.inf, dtype=tf.float32)
  ones = tf.ones([height, width], dtype=tf.float32)
  neg_inf_r2 = ones * neg_inf_r0
  inputs = tf.where(tf.is_nan(inputs), neg_inf_r2, inputs)

  # Select the current largest value k times; each selected maximum is masked
  # to -inf so the next iteration picks the next largest value.
  tmp = inputs
  topk_r2 = tf.zeros([height, k], dtype=tf.float32)
  for i in range(k):
    kth_order_statistic = tf.reduce_max(tmp, axis=1, keepdims=True)
    k_mask = tf.tile(
        tf.expand_dims(tf.equal(tf.range(k), tf.fill([k], i)), 0), [height, 1])
    topk_r2 = tf.where(k_mask, tf.tile(kth_order_statistic, [1, k]), topk_r2)
    ge_r2 = tf.greater_equal(inputs, tf.tile(kth_order_statistic, [1, width]))
    tmp = tf.where(ge_r2, neg_inf_r2, inputs)

  # The column index was packed into the low bits by _create_make_unique;
  # recover it by masking the bit pattern of each selected value.
  log2_ceiling = int(math.ceil(math.log(float(int(width)), 2)))
  next_power_of_two = 1 << log2_ceiling
  count_mask = next_power_of_two - 1
  mask_r0 = tf.constant(count_mask)
  mask_r2 = tf.fill([height, k], mask_r0)
  topk_r2_s32 = tf.bitcast(topk_r2, tf.int32)
  topk_indices_r2 = tf.bitwise.bitwise_and(topk_r2_s32, mask_r2)
  return topk_r2, topk_indices_r2
def split_one_doc_to_true_len_sens(doc_t, split_token, padding_token,
                                   max_doc_len, max_sen_len):
  """
  Split a document to sentences with true sentence lengths.
  doc_t: [doc_word_len]
  out_t: [max_doc_len, max_sen_len]
  """
  if len(doc_t.get_shape()) == 1:
    split_token_index = tf.squeeze(
        tf.where(tf.equal(doc_t, split_token)), axis=1)
    split_token_index.set_shape([None])
    split_len_part_1 = split_token_index[:1] + 1
    split_len_part_2 = split_token_index[1:] - split_token_index[:-1]
    split_lens = tf.concat([split_len_part_1, split_len_part_2], axis=0)
    split_lens = cut_or_padding(
        split_lens, max_doc_len, padding_token=padding_token)
    new_doc_len = tf.reduce_sum(split_lens)
    splited_sentences = tf.split(doc_t[:new_doc_len], split_lens)
    splited_sentences = [
        cut_or_padding(s, max_sen_len) for s in splited_sentences
    ]
    out_t = tf.stack(splited_sentences)
    padding_tokens = tf.multiply(
        tf.ones_like(out_t, dtype=tf.int32), padding_token)
    out_t = tf.where(tf.equal(out_t, split_token), padding_tokens, out_t)
    return out_t

  raise ValueError("doc_t should be a tensor with rank 1.")
def masked_softmax(logits, mask, axis):
  """Compute softmax with input mask."""
  e_logits = tf.exp(logits)
  masked_e = tf.multiply(e_logits, mask)
  sum_masked_e = tf.reduce_sum(masked_e, axis, keep_dims=True)
  ones = tf.ones_like(sum_masked_e)
  # if the mask is all zeros along `axis`, the denominator must be set to 1
  # to avoid dividing by zero
  sum_masked_e_safe = tf.where(tf.equal(sum_masked_e, 0), ones, sum_masked_e)
  return masked_e / sum_masked_e_safe
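# Usage sketch (illustrative, not from the source): masked softmax over a
# padded batch, assuming TF 1.x graph mode; all values below are made up.
#
#   logits = tf.constant([[1.0, 2.0, 3.0], [0.5, 0.1, 0.2]])
#   mask = tf.constant([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
#   probs = masked_softmax(logits, mask, axis=-1)
#   # masked positions get zero probability and each row still sums to 1;
#   # an all-zero mask row yields all zeros instead of NaN.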
def labels_blankid_to_last(labels, blank_index, num_class=None):
  ''' Change the value of blank_label elements from blank_index to
      num_class - 1.
  '''
  assert num_class is not None, 'The num_class should not be None!'

  labels = transform_preprocess(
      labels=labels, blank_index=blank_index, num_class=num_class)

  labels_values = labels.values
  labels_num_class = tf.zeros_like(labels_values, dtype=tf.int32) + num_class
  # map the blank elements to num_class, then shift every value above
  # blank_index down by one, so the blank label ends up at num_class - 1
  labels_values_change_blank = tf.where(
      tf.equal(labels_values, blank_index), labels_num_class, labels_values)
  labels_values = tf.where(labels_values_change_blank < blank_index,
                           labels_values_change_blank,
                           labels_values_change_blank - 1)

  labels = tf.SparseTensor(
      indices=labels.indices,
      values=labels_values,
      dense_shape=labels.dense_shape)
  return labels
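# Worked example (hypothetical values): with num_class=5 and blank_index=2,
# sparse label values [1, 2, 3, 4] become [1, 4, 2, 3] -- the blank id 2 is
# moved to num_class - 1 = 4 and the ids above it are shifted down by one.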
def compute_sen_lens(inputs, padding_token=0):
  """
  Count how many words in a sentence.
  inputs: [..., time_steps]
  sen_lens: [...]
  """
  x_binary = tf.cast(tf.not_equal(inputs, padding_token), tf.int32)
  sen_lens = tf.reduce_sum(x_binary, axis=-1)
  ones = tf.ones_like(sen_lens)
  sen_lens = tf.where(tf.equal(sen_lens, utils.PAD_IDX), x=ones, y=sen_lens)
  return sen_lens
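# Usage sketch (illustrative): with the default padding_token=0 and assuming
# utils.PAD_IDX == 0, an empty sentence gets length 1 rather than 0.
#
#   inputs = tf.constant([[3, 5, 0, 0], [0, 0, 0, 0]])
#   sen_lens = compute_sen_lens(inputs)   # -> [2, 1]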
def compute_lens(inputs, max_len):
  """count sequence length.
  input: [batch_size, max_len]
  lens: [batch_size]
  """
  x_binary = tf.cast(tf.cast(tf.reverse(inputs, axis=[1]), tf.bool), tf.int32)
  lens = max_len - tf.argmax(x_binary, axis=1, output_type=tf.int32)

  zeros = tf.zeros_like(lens, dtype=tf.int32)
  x_sum = tf.reduce_sum(inputs, axis=1)
  sen_lens = tf.where(tf.equal(x_sum, 0), zeros, lens)
  return sen_lens
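# Usage sketch (illustrative): the argmax over the reversed binary mask finds
# the last non-zero position, so trailing padding is ignored while internal
# zeros still count toward the length; an all-zero row maps to length 0.
#
#   inputs = tf.constant([[2, 7, 4, 0, 0], [0, 0, 0, 0, 0]])
#   lens = compute_lens(inputs, max_len=5)   # -> [3, 0]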
def labels_last_to_blankid(labels, blank_index, num_class=None):
  ''' Change the value of blank_label elements from num_classes - 1 to
      blank_index, after removing blank_index by decoder.
  '''
  labels = transform_preprocess(
      labels=labels, blank_index=blank_index, num_class=num_class)

  labels_values = labels.values
  labels_change_blank_id = tf.where(labels_values >= blank_index,
                                    labels_values + 1, labels_values)

  labels = tf.SparseTensor(
      indices=labels.indices,
      values=labels_change_blank_id,
      dense_shape=labels.dense_shape)
  return labels
def transform_preprocess(labels=None, blank_index=None, num_class=None):
  ''' Ensure that the value of blank_index is in a reasonable range,
      and transform the DenseTensor labels to a SparseTensor.
  '''
  if blank_index is None or blank_index < 0:
    raise ValueError('blank_index must be greater than or equal to zero')

  if num_class is not None and blank_index > (num_class - 1):
    raise ValueError('blank_index must be less than or equal to num_class - 1')

  if labels is None:
    return None

  if not isinstance(labels, tf.SparseTensor):
    labels = tf.cast(labels, tf.int32)
    labels_idx = tf.where(tf.not_equal(labels, 0))
    labels_values = tf.gather_nd(labels, labels_idx)
    labels_shape = tf.cast(tf.shape(labels), dtype=tf.int64)
    labels = tf.SparseTensor(
        indices=labels_idx, values=labels_values, dense_shape=labels_shape)

  return labels
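# Usage sketch (illustrative): dense labels padded with 0 are converted to a
# SparseTensor whose values are the non-zero entries.
#
#   dense = tf.constant([[2, 3, 0], [4, 0, 0]])
#   sparse = transform_preprocess(labels=dense, blank_index=0, num_class=5)
#   # sparse.values -> [2, 3, 4], sparse.dense_shape -> [2, 3]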
def _create_make_unique(inputs):
  """Replaces the lower bits of each element with iota."""
  if inputs.shape.ndims != 2:
    raise ValueError("Input of top_k_with_unique must be rank-2 "
                     "but got: %s" % inputs.shape)

  height = inputs.shape[0]
  width = inputs.shape[1]
  zeros = tf.zeros([height, width], dtype=tf.int32)

  # count_mask is used to mask away the low-order bits so that every element
  # can be made distinct.
  log2_ceiling = int(math.ceil(math.log(int(width), 2)))
  next_power_of_two = 1 << log2_ceiling
  count_mask = ~(next_power_of_two - 1)
  count_mask_r0 = tf.constant(count_mask)
  count_mask_r2 = tf.fill([height, width], count_mask_r0)

  # smallest_normal is the bit representation of the smallest positive normal
  # floating point number; its sign bit is zero.
  smallest_normal = 1 << 23
  smallest_normal_r0 = tf.constant(smallest_normal, dtype=tf.int32)
  smallest_normal_r2 = tf.fill([height, width], smallest_normal_r0)

  # low_bit_mask is used to mask away the sign bit when computing the
  # absolute value.
  low_bit_mask = ~(1 << 31)
  low_bit_mask_r0 = tf.constant(low_bit_mask, dtype=tf.int32)
  low_bit_mask_r2 = tf.fill([height, width], low_bit_mask_r0)

  iota = tf.tile(
      tf.expand_dims(tf.range(width, dtype=tf.int32), 0), [height, 1])

  # Compare the absolute value with positive zero to handle negative zero.
  input_r2 = tf.bitcast(inputs, tf.int32)
  abs_r2 = tf.bitwise.bitwise_and(input_r2, low_bit_mask_r2)
  if_zero_r2 = tf.equal(abs_r2, zeros)
  smallest_normal_preserving_sign_r2 = tf.bitwise.bitwise_or(
      input_r2, smallest_normal_r2)
  input_no_zeros_r2 = tf.where(if_zero_r2, smallest_normal_preserving_sign_r2,
                               input_r2)

  # Discard the low-order bits and replace them with iota.
  and_r2 = tf.bitwise.bitwise_and(input_no_zeros_r2, count_mask_r2)
  or_r2 = tf.bitwise.bitwise_or(and_r2, iota)
  return tf.bitcast(or_r2, tf.float32)
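# The two helpers above are typically composed into a deterministic top-k with
# index tie-breaking: each value gets its column index packed into its low
# bits, and the bits are masked back out after selection. A minimal sketch of
# that composition, assuming a rank-2 float32 input; the wrapper name below is
# illustrative and may not exist in this module:
#
#   def top_k_with_unique(inputs, k):
#     unique_inputs = _create_make_unique(tf.cast(inputs, tf.float32))
#     top_values, indices = _create_topk_unique(unique_inputs, k)
#     top_values = tf.cast(top_values, inputs.dtype)
#     return top_values, indices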
def beam_search(symbols_to_logits_fn,
                initial_ids,
                beam_size,
                decode_length,
                vocab_size,
                alpha,
                eos_id,
                states=None,
                stop_early=True,
                INF=1. * 1e20):
  """Beam search with length penalties."""
  batch_size = utils.shape_list(initial_ids)[0]

  initial_log_probs = tf.constant([[0.] + [-INF] * (beam_size - 1)])
  # (batch_size, beam_size)
  alive_log_probs = tf.tile(initial_log_probs, [batch_size, 1])
  alive_seq = utils.expand_to_beam_size(initial_ids, beam_size)
  # (batch_size, beam_size, 1)
  alive_seq = tf.expand_dims(alive_seq, axis=2)
  if states:
    states = nest.map_structure(
        lambda state: utils.expand_to_beam_size(state, beam_size), states)
  else:
    states = {}

  # (batch_size, beam_size, 1)
  finished_seq = tf.zeros(utils.shape_list(alive_seq), tf.int32)
  # (batch_size, beam_size)
  finished_scores = tf.ones([batch_size, beam_size]) * -INF
  # (batch_size, beam_size)
  finished_flags = tf.zeros([batch_size, beam_size], tf.bool)

  def grow_finished(finished_seq, finished_scores, finished_flags, curr_seq,
                    curr_scores, curr_finished):
    """
    Given sequences and scores from finished sequence and current finished
    sequence, will gather the top k=beam size sequences to update finished seq.
    """
    # padding zero for finished seq
    finished_seq = tf.concat(
        [finished_seq, tf.zeros([batch_size, beam_size, 1], tf.int32)], axis=2)
    # mask unfinished curr seq
    curr_scores += (1. - tf.to_float(curr_finished)) * -INF
    # concatenating the sequences and scores along beam axis
    # (batch_size, 2xbeam_size, seq_len)
    curr_finished_seq = tf.concat([finished_seq, curr_seq], axis=1)
    curr_finished_scores = tf.concat([finished_scores, curr_scores], axis=1)
    curr_finished_flags = tf.concat([finished_flags, curr_finished], axis=1)
    return utils.compute_topk_scores_and_seq(
        curr_finished_seq, curr_finished_scores, curr_finished_scores,
        curr_finished_flags, beam_size, batch_size, "grow_finished")

  def grow_alive(curr_seq, curr_scores, curr_log_probs, curr_finished, states):
    """Given sequences and scores, will gather the top k=beam size sequences."""
    curr_scores += tf.to_float(curr_finished) * -INF
    return utils.compute_topk_scores_and_seq(curr_seq, curr_scores,
                                             curr_log_probs, curr_finished,
                                             beam_size, batch_size,
                                             "grow_alive", states)

  def grow_topk(i, alive_seq, alive_log_probs, states):
    """Inner beam search loop."""
    # (batch_size * beam_size, decoded_length)
    flat_ids = tf.reshape(alive_seq, [batch_size * beam_size, -1])

    if states:
      flat_states = nest.map_structure(utils.merge_beam_dim, states)
      flat_logits, flat_states = symbols_to_logits_fn(flat_ids, i, flat_states)
      states = nest.map_structure(
          lambda t: utils.unmerge_beam_dim(t, batch_size, beam_size),
          flat_states)
    else:
      flat_logits = symbols_to_logits_fn(flat_ids)

    logits = tf.reshape(flat_logits, [batch_size, beam_size, -1])
    candidate_log_probs = utils.log_prob_from_logits(logits)
    log_probs = candidate_log_probs + tf.expand_dims(alive_log_probs, axis=2)

    length_penalty = tf.pow(((5. + tf.to_float(i + 1)) / 6.), alpha)

    curr_scores = log_probs / length_penalty
    flat_curr_scores = tf.reshape(curr_scores, [-1, beam_size * vocab_size])

    topk_scores, topk_ids = tf.nn.top_k(flat_curr_scores, k=beam_size * 2)
    topk_log_probs = topk_scores * length_penalty

    topk_beam_index = topk_ids // vocab_size
    topk_ids %= vocab_size  # Unflatten the ids

    batch_pos = utils.compute_batch_indices(batch_size, beam_size * 2)
    topk_coordinates = tf.stack([batch_pos, topk_beam_index], axis=2)
    topk_seq = tf.gather_nd(alive_seq, topk_coordinates)
    if states:
      states = nest.map_structure(
          lambda state: tf.gather_nd(state, topk_coordinates), states)

    topk_seq = tf.concat([topk_seq, tf.expand_dims(topk_ids, axis=2)], axis=2)

    topk_finished = tf.equal(topk_ids, eos_id)

    return topk_seq, topk_log_probs, topk_scores, topk_finished, states

  def inner_loop(i, alive_seq, alive_log_probs, finished_seq, finished_scores,
                 finished_flags, states):
    """Inner beam search loop."""
    topk_seq, topk_log_probs, topk_scores, topk_finished, states = grow_topk(
        i, alive_seq, alive_log_probs, states)
    alive_seq, alive_log_probs, _, states = grow_alive(
        topk_seq, topk_scores, topk_log_probs, topk_finished, states)
    finished_seq, finished_scores, finished_flags, _ = grow_finished(
        finished_seq, finished_scores, finished_flags, topk_seq, topk_scores,
        topk_finished)
    return (i + 1, alive_seq, alive_log_probs, finished_seq, finished_scores,
            finished_flags, states)

  def _is_finished(i, unused_alive_seq, alive_log_probs, unused_finished_seq,
                   finished_scores, unused_finished_in_finished,
                   unused_states):
    """Checking termination condition."""
    max_length_penalty = tf.pow(((5. + tf.to_float(decode_length)) / 6.),
                                alpha)
    lower_bound_alive_scores = alive_log_probs[:, 0] / max_length_penalty

    if not stop_early:
      lowest_score_of_finished_in_finished = tf.reduce_min(finished_scores)
    else:
      lowest_score_of_finished_in_finished = tf.reduce_max(finished_scores,
                                                           axis=1)

    bound_is_met = tf.reduce_all(
        tf.greater(lowest_score_of_finished_in_finished,
                   lower_bound_alive_scores))

    return tf.logical_and(tf.less(i, decode_length),
                          tf.logical_not(bound_is_met))

  inner_shape = tf.TensorShape([None, None, None])
  state_struc = nest.map_structure(utils.get_state_shape_invariants, states)
  (_, alive_seq, alive_log_probs, finished_seq, finished_scores,
   finished_flags, states) = tf.while_loop(
       _is_finished,
       inner_loop, [
           tf.constant(0), alive_seq, alive_log_probs, finished_seq,
           finished_scores, finished_flags, states
       ],
       shape_invariants=[
           tf.TensorShape([]), inner_shape,
           alive_log_probs.get_shape(), inner_shape,
           finished_scores.get_shape(),
           finished_flags.get_shape(), state_struc
       ],
       parallel_iterations=1,
       back_prop=False)

  alive_seq.set_shape((None, beam_size, None))
  finished_seq.set_shape((None, beam_size, None))
  finished_seq = tf.where(
      tf.reduce_any(finished_flags, 1), finished_seq, alive_seq)
  finished_scores = tf.where(
      tf.reduce_any(finished_flags, 1), finished_scores, alive_log_probs)
  return finished_seq, finished_scores, states
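# Usage sketch (illustrative only): decode with a toy scorer that always
# prefers token id 1, taken to be the eos id here; every size below is made
# up and `utils` must provide the helpers referenced above.
#
#   vocab_size = 8
#
#   def symbols_to_logits_fn(ids):
#     batch = utils.shape_list(ids)[0]
#     return tf.one_hot(tf.fill([batch], 1), vocab_size) * 10.0
#
#   initial_ids = tf.zeros([4], dtype=tf.int32)   # batch_size = 4
#   seqs, scores, _ = beam_search(symbols_to_logits_fn, initial_ids,
#                                 beam_size=2, decode_length=5,
#                                 vocab_size=vocab_size, alpha=0.6, eos_id=1)
#   # seqs: (4, 2, seq_len) token ids; scores: (4, 2) length-penalized
#   # log-probabilities.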
def arcface_loss(embedding,
                 labels,
                 out_num,
                 weights=None,
                 s=64.,
                 m=0.5,
                 limit_to_pi=True):
  '''
  https://github.com/auroua/InsightFace_TF/blob/master/losses/face_losses.py
  :param embedding: the input embedding vectors
  :param labels: the input labels, the shape should be eg: (batch_size, 1)
  :param s: scalar value, default is 64
  :param out_num: output class num
  :param weights: a tf.variable with shape (embedding.shape[-1], out_num),
                  or None to create a new one internally. default = None
  :param m: the margin value, default is 0.5
  :return: the final calculated output; this output is sent into tf.nn.softmax directly
  '''
  cos_m = math.cos(m)
  sin_m = math.sin(m)
  mm = sin_m * m  # issue 1
  threshold = math.cos(math.pi - m)
  with tf.variable_scope('arcface_loss'):
    # inputs and weights norm
    embedding_norm = tf.norm(embedding, axis=1, keep_dims=True)
    embedding = tf.div(embedding, embedding_norm, name='norm_embedding')
    if weights is None:
      weights = tf.get_variable(
          name='weights',
          shape=[embedding.shape[-1].value, out_num],
          initializer=tf.glorot_uniform_initializer())
    weights_norm = tf.norm(weights, axis=0, keep_dims=True)
    weights = tf.div(weights, weights_norm, name='norm_weights')
    # cos(theta+m)
    cos_t = tf.matmul(embedding, weights, name='cos_t')
    cos_t2 = tf.square(cos_t, name='cos_2')
    sin_t2 = tf.subtract(1., cos_t2, name='sin_2')
    sin_t = tf.sqrt(sin_t2, name='sin_t')
    cos_mt = s * tf.subtract(
        tf.multiply(cos_t, cos_m), tf.multiply(sin_t, sin_m), name='cos_mt')

    if limit_to_pi:
      # this condition controls that theta+m stays in range [0, pi]
      #   0 <= theta + m <= pi
      #   -m <= theta <= pi - m
      cond_v = cos_t - threshold
      cond = tf.cast(tf.nn.relu(cond_v, name='if_else'), dtype=tf.bool)
      keep_val = s * (cos_t - mm)
      cos_mt_temp = tf.where(cond, cos_mt, keep_val)
    else:
      cos_mt_temp = cos_mt

    mask = tf.one_hot(labels, depth=out_num, name='one_hot_mask')
    # mask = tf.squeeze(mask, 1)
    inv_mask = tf.subtract(1., mask, name='inverse_mask')
    s_cos_t = tf.multiply(s, cos_t, name='scalar_cos_t')
    output = tf.add(
        tf.multiply(s_cos_t, inv_mask),
        tf.multiply(cos_mt_temp, mask),
        name='arcface_loss_output')
  return output
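# Usage sketch (illustrative, shapes assumed): use the returned margins as
# logits for a standard softmax cross-entropy head over out_num identities.
#
#   embedding = tf.placeholder(tf.float32, [None, 512])
#   labels = tf.placeholder(tf.int32, [None])
#   logits = arcface_loss(embedding, labels, out_num=1000, s=64., m=0.5)
#   loss = tf.reduce_mean(
#       tf.nn.sparse_softmax_cross_entropy_with_logits(
#           labels=labels, logits=logits))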
def mask_outputs(origin_outputs):
  """mask position embedding"""
  inputs, outputs = origin_outputs
  outputs = tf.where(tf.equal(inputs, 0), inputs, outputs)
  return outputs
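# Usage sketch (illustrative): zero out position ids wherever the word id is
# the padding id 0, so padded positions contribute no position embedding.
# Both tensors must share shape and dtype for tf.where.
#
#   word_ids = tf.constant([[5, 9, 0, 0]])
#   pos_ids = tf.constant([[1, 2, 3, 4]])
#   masked_pos = mask_outputs((word_ids, pos_ids))   # -> [[1, 2, 0, 0]]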