def _create_topk_unique(inputs, k):
  """Creates the top k values in sorted order with indices."""
  height = inputs.shape[0]
  width = inputs.shape[1]
  neg_inf_r0 = tf.constant(-np.inf, dtype=tf.float32)
  ones = tf.ones([height, width], dtype=tf.float32)
  neg_inf_r2 = ones * neg_inf_r0
  inputs = tf.where(tf.is_nan(inputs), neg_inf_r2, inputs)

  tmp = inputs
  topk_r2 = tf.zeros([height, k], dtype=tf.float32)
  for i in range(k):
    kth_order_statistic = tf.reduce_max(tmp, axis=1, keepdims=True)
    k_mask = tf.tile(
        tf.expand_dims(tf.equal(tf.range(k), tf.fill([k], i)), 0),
        [height, 1])
    topk_r2 = tf.where(k_mask, tf.tile(kth_order_statistic, [1, k]), topk_r2)
    ge_r2 = tf.greater_equal(inputs, tf.tile(kth_order_statistic, [1, width]))
    tmp = tf.where(ge_r2, neg_inf_r2, inputs)

  log2_ceiling = int(math.ceil(math.log(float(int(width)), 2)))
  next_power_of_two = 1 << log2_ceiling
  count_mask = next_power_of_two - 1
  mask_r0 = tf.constant(count_mask)
  mask_r2 = tf.fill([height, k], mask_r0)
  topk_r2_s32 = tf.bitcast(topk_r2, tf.int32)
  topk_indices_r2 = tf.bitwise.bitwise_and(topk_r2_s32, mask_r2)
  return topk_r2, topk_indices_r2
def splice(feat, left_context, right_context):
  '''
  splice frame with context
    param: feat, tf.float32, [batch, time, feat]
    return: feat, tf.float32, [batch, time, feat*(left_context + 1 + right_context)]
    reference:
      https://github.com/kaldi-asr/kaldi/src/feat/feature-functions.cc#L205:6
  '''

  def _loop_continue(time, end_time, context, unused_left_context,
                     right_context, unused_output_tas):
    del unused_output_tas
    del unused_left_context
    return time < end_time

  def _loop_body(time, end_time, context, left_context, right_context,
                 output_tas):
    shape = tf.shape(context)
    B, _, D = shape[0], shape[1], shape[2]
    N = (1 + left_context + right_context) * D

    new_feat = context[:, time:time + left_context + 1 + right_context, :]
    new_feat = tf.reshape(new_feat, [B, N])
    new_output_tas = output_tas.write(time, new_feat)
    return (time + 1, end_time, context, left_context, right_context,
            new_output_tas)

  with tf.control_dependencies([
      tf.assert_greater_equal(left_context, 0),
      tf.assert_greater_equal(right_context, 0)
  ]):
    T = tf.shape(feat)[1]
    output_tas = _new_tensor_array('splice_feat_ta', T, dtype=tf.float32)
    time = tf.constant(0, tf.int32)
    first = tf.tile(feat[:, 0:1, :], [1, left_context, 1])
    last = tf.tile(feat[:, -1:, :], [1, right_context, 1])
    context = tf.concat([first, feat], axis=1)
    context = tf.concat([context, last], axis=1)

    loop_vars = (time, T, context, left_context, right_context, output_tas)
    parallel_iterations = 10
    shape_invariants = tf.nest.map_structure(
        lambda t: tf.TensorShape(None), loop_vars)

    (time, end_time, context, left_context, right_context,
     output_tas) = tf.while_loop(
         _loop_continue,
         _loop_body,
         loop_vars=loop_vars,
         shape_invariants=shape_invariants,
         parallel_iterations=parallel_iterations,
         swap_memory=False)
    del context
    del left_context
    del right_context

    batch_spliced_feats = output_tas.stack()
    batch_spliced_feats = tf.transpose(batch_spliced_feats, [1, 0, 2])
  return batch_spliced_feats
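# A minimal usage sketch for splice, not part of the original source. It assumes
# the module-level helpers above (e.g. _new_tensor_array) are in scope and a
# TF 1.x graph/session environment. With left_context=1 and right_context=1 each
# output frame is the previous, current and next frame concatenated, with the
# edge frames repeated.
feat = tf.reshape(tf.range(6, dtype=tf.float32), [1, 3, 2])  # [B=1, T=3, D=2]
spliced = splice(feat, left_context=1, right_context=1)      # [1, 3, 6]
with tf.Session() as sess:
  print(sess.run(spliced))
  # frame 0 -> [0. 1. 0. 1. 2. 3.]   (left edge repeated)
  # frame 1 -> [0. 1. 2. 3. 4. 5.]
  # frame 2 -> [2. 3. 4. 5. 4. 5.]   (right edge repeated)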
def call(self, inputs, training=None, mask=None):
  batch_size = tf.shape(inputs)[0]
  W_3d = tf.tile(tf.expand_dims(self.W, axis=0),
                 tf.stack([batch_size, 1, 1]))
  # [batch_size, steps, features]
  input_projection = tf.matmul(inputs, W_3d)

  if self.use_bias:
    input_projection += self.b

  input_projection = tf.tanh(input_projection)

  # [batch_size, steps, 1]
  similaritys = tf.reduce_sum(
      tf.multiply(input_projection, self.attention_context_vector),
      axis=2,
      keepdims=True)

  # [batch_size, steps, 1]
  if mask is not None:
    attention_weights = masked_softmax(similaritys, mask, axis=1)
  else:
    attention_weights = tf.nn.softmax(similaritys, axis=1)

  # [batch_size, features]
  attention_output = tf.reduce_sum(
      tf.multiply(inputs, attention_weights), axis=1)
  return attention_output
def _expand_to_beam_size(tensor, beam_size):
  """Tiles a given tensor by beam_size."""
  tensor = tf.expand_dims(tensor, axis=1)
  tile_dims = [1] * tensor.shape.ndims
  tile_dims[1] = beam_size
  return tf.tile(tensor, tile_dims)
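# Shape-only sketch of _expand_to_beam_size (illustrative values assumed, not
# from the original source): a (batch, ...) tensor gains a beam axis at
# position 1 and is tiled along it.
x = tf.zeros([4, 7])                        # (batch=4, d=7)
y = _expand_to_beam_size(x, beam_size=3)    # (4, 3, 7)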
def call(self, tensors):
  """Attention layer."""
  left, right = tensors

  len_left = left.shape[1]
  len_right = right.shape[1]
  tensor_left = tf.expand_dims(left, axis=2)
  tensor_right = tf.expand_dims(right, axis=1)
  tensor_left = tf.tile(tensor_left, [1, 1, len_right, 1])
  tensor_right = tf.tile(tensor_right, [1, len_left, 1, 1])
  tensor_merged = tf.concat([tensor_left, tensor_right], axis=-1)
  middle_output = self.middle_layer(tensor_merged)
  attn_scores = self.attn(middle_output)
  attn_scores = tf.squeeze(attn_scores, axis=3)
  exp_attn_scores = tf.exp(
      attn_scores - tf.reduce_max(attn_scores, axis=-1, keepdims=True))
  exp_sum = tf.reduce_sum(exp_attn_scores, axis=-1, keepdims=True)
  attention_weights = exp_attn_scores / exp_sum
  return tf.matmul(attention_weights, right)
def splice_layer(x, name, context):
  '''
  Splice a tensor along the last dimension with context.
  e.g.:
    t = [[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]
    splice_layer(t, 'splice', [0, 1]) =
      [[[1, 2, 3, 4, 5, 6],
        [4, 5, 6, 7, 8, 9],
        [7, 8, 9, 7, 8, 9]]]

  Args:
    x: a tf.Tensor with shape (B, T, D) a.k.a. (N, H, W)
    context: a list of context offsets

  Returns:
    spliced tensor with shape (B, T, D * len(context))
  '''
  with tf.variable_scope(name):
    input_shape = tf.shape(x)
    B, T = input_shape[0], input_shape[1]
    context_len = len(context)

    array = tf.TensorArray(x.dtype, size=context_len)
    for idx, offset in enumerate(context):
      begin = offset
      end = T + offset
      if begin < 0:
        begin = 0
        sliced = x[:, begin:end, :]
        tiled = tf.tile(x[:, 0:1, :], [1, abs(offset), 1])
        final = tf.concat((tiled, sliced), axis=1)
      else:
        end = T
        sliced = x[:, begin:end, :]
        tiled = tf.tile(x[:, -1:, :], [1, abs(offset), 1])
        final = tf.concat((sliced, tiled), axis=1)
      array = array.write(idx, final)

    spliced = array.stack()
    spliced = tf.transpose(spliced, (1, 2, 0, 3))
    spliced = tf.reshape(spliced, (B, T, -1))
  return spliced
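# Minimal check of splice_layer against the docstring example above (a hedged
# sketch, not from the original source; assumes a TF 1.x graph/session
# environment since tf.variable_scope is used).
t = tf.constant([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]], dtype=tf.float32)  # (1, 3, 3)
out = splice_layer(t, 'splice', [0, 1])                                  # (1, 3, 6)
with tf.Session() as sess:
  print(sess.run(out))
  # [[[1. 2. 3. 4. 5. 6.]
  #   [4. 5. 6. 7. 8. 9.]
  #   [7. 8. 9. 7. 8. 9.]]]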
def _reshape_mask(mask):
  """
  repeat mask for multi head
    Input shape: (Batch size, steps)
    Output shape: (Batch size * head num, steps)
  """
  if mask is None:
    return None
  seq_len = tf.shape(mask)[1]
  mask = tf.expand_dims(mask, axis=1)
  # self.head_num comes from the enclosing multi-head attention layer; this
  # helper is defined inside that layer's call method.
  mask = tf.tile(mask, [1, self.head_num, 1])
  return tf.reshape(mask, shape=(-1, seq_len))
def call(self, inps, training=None, mask=None):
  if not self.is_infer:
    dec_inp, enc_out = inps
    with tf.name_scope('while'):
      dec_out = self.decode(dec_inp, enc_out, training, mask)
      scores = self.final_dense(dec_out)
      return scores
  else:
    enc_out = inps
    init_ids = tf.cast(
        tf.ones([utils.shape_list(enc_out)[0]]) * self.sos_id, tf.int32)

    # Beam Search
    enc_shape = utils.shape_list(enc_out)
    enc_out = tf.tile(
        tf.expand_dims(enc_out, axis=1), [1, self.beam_size, 1, 1])
    enc_out = tf.reshape(
        enc_out, [enc_shape[0] * self.beam_size, enc_shape[1], enc_shape[2]])
    enc_mask = tf.tile(
        tf.expand_dims(mask, axis=1), [1, self.beam_size, 1, 1, 1])
    enc_mask = tf.reshape(enc_mask,
                          [enc_shape[0] * self.beam_size, 1, 1, -1])

    def symbols_to_logits_fn(dec_inps):
      dec_out = self.decode(dec_inps, enc_out, training, enc_mask)
      scores = self.final_dense(dec_out)
      return scores[:, -1, :]

    decoded_ids, scores, _ = self.beam_search(symbols_to_logits_fn, init_ids,
                                              self.beam_size,
                                              self.max_dec_len,
                                              self.vocab_size,
                                              self.length_penalty,
                                              self.eos_id)
    decoded_ids = decoded_ids[:, 0, 1:]

    return decoded_ids
def call(self, inputs: list, **kwargs) -> typing.Any:
  """
  The computation logic of DynamicPoolingLayer.
  :param inputs: two input tensors.
  """
  self._validate_dpool_size()
  x, dpool_index = inputs
  dpool_shape = tf.shape(dpool_index)
  batch_index_one = tf.expand_dims(
      tf.expand_dims(tf.range(dpool_shape[0]), axis=-1), axis=-1)
  batch_index = tf.expand_dims(
      tf.tile(batch_index_one, [1, self._msize1, self._msize2]), axis=-1)
  dpool_index_ex = tf.concat([batch_index, dpool_index], axis=3)
  x_expand = tf.gather_nd(x, dpool_index_ex)
  stride1 = self._msize1 // self._psize1
  stride2 = self._msize2 // self._psize2

  x_pool = tf.nn.max_pool(x_expand, [1, stride1, stride2, 1],
                          [1, stride1, stride2, 1], "VALID")
  return x_pool
def _create_make_unique(inputs):
  """Replaces the lower bits of each element with iota."""
  if inputs.shape.ndims != 2:
    raise ValueError("Input of top_k_with_unique must be rank-2 "
                     "but got: %s" % inputs.shape)
  height = inputs.shape[0]
  width = inputs.shape[1]
  zeros = tf.zeros([height, width], dtype=tf.int32)

  log2_ceiling = int(math.ceil(math.log(int(width), 2)))
  next_power_of_two = 1 << log2_ceiling
  count_mask = ~(next_power_of_two - 1)
  count_mask_r0 = tf.constant(count_mask)
  count_mask_r2 = tf.fill([height, width], count_mask_r0)

  smallest_normal = 1 << 23
  smallest_normal_r0 = tf.constant(smallest_normal, dtype=tf.int32)
  smallest_normal_r2 = tf.fill([height, width], smallest_normal_r0)

  low_bit_mask = ~(1 << 31)
  low_bit_mask_r0 = tf.constant(low_bit_mask, dtype=tf.int32)
  low_bit_mask_r2 = tf.fill([height, width], low_bit_mask_r0)

  iota = tf.tile(
      tf.expand_dims(tf.range(width, dtype=tf.int32), 0), [height, 1])

  input_r2 = tf.bitcast(inputs, tf.int32)
  abs_r2 = tf.bitwise.bitwise_and(input_r2, low_bit_mask_r2)
  if_zero_r2 = tf.equal(abs_r2, zeros)
  smallest_normal_preserving_sign_r2 = tf.bitwise.bitwise_or(
      input_r2, smallest_normal_r2)
  input_no_zeros_r2 = tf.where(if_zero_r2,
                               smallest_normal_preserving_sign_r2, input_r2)

  and_r2 = tf.bitwise.bitwise_and(input_no_zeros_r2, count_mask_r2)
  or_r2 = tf.bitwise.bitwise_or(and_r2, iota)
  return tf.bitcast(or_r2, tf.float32)
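# A hedged sketch of how the two helpers above are typically composed into a
# single "top-k with unique elements" op: _create_make_unique first overwrites
# the low mantissa bits of each float with its column index so ties break
# deterministically, and _create_topk_unique then recovers those indices from
# the same bits. The wrapper name top_k_with_unique is an assumption, not part
# of the original source.
def top_k_with_unique(inputs, k):
  unique_inputs = _create_make_unique(tf.cast(inputs, tf.float32))
  top_values, indices = _create_topk_unique(unique_inputs, k)
  top_values = tf.cast(top_values, inputs.dtype)
  return top_values, indices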
def beam_search(symbols_to_logits_fn,
                initial_ids,
                beam_size,
                decode_length,
                vocab_size,
                alpha,
                eos_id,
                states=None,
                stop_early=True,
                INF=1. * 1e20):
  """Beam search with length penalties."""
  batch_size = utils.shape_list(initial_ids)[0]

  initial_log_probs = tf.constant([[0.] + [-INF] * (beam_size - 1)])
  # (batch_size, beam_size)
  alive_log_probs = tf.tile(initial_log_probs, [batch_size, 1])

  alive_seq = utils.expand_to_beam_size(initial_ids, beam_size)
  # (batch_size, beam_size, 1)
  alive_seq = tf.expand_dims(alive_seq, axis=2)

  if states:
    states = nest.map_structure(
        lambda state: utils.expand_to_beam_size(state, beam_size), states)
  else:
    states = {}

  # (batch_size, beam_size, 1)
  finished_seq = tf.zeros(utils.shape_list(alive_seq), tf.int32)
  # (batch_size, beam_size)
  finished_scores = tf.ones([batch_size, beam_size]) * -INF
  # (batch_size, beam_size)
  finished_flags = tf.zeros([batch_size, beam_size], tf.bool)

  def grow_finished(finished_seq, finished_scores, finished_flags, curr_seq,
                    curr_scores, curr_finished):
    """
    Given sequences and scores from finished sequence and current finished
    sequence, will gather the top k=beam size sequences to update finished seq.
    """
    # padding zero for finished seq
    finished_seq = tf.concat(
        [finished_seq, tf.zeros([batch_size, beam_size, 1], tf.int32)],
        axis=2)
    # mask unfinished curr seq
    curr_scores += (1. - tf.to_float(curr_finished)) * -INF

    # concatenating the sequences and scores along beam axis
    # (batch_size, 2xbeam_size, seq_len)
    curr_finished_seq = tf.concat([finished_seq, curr_seq], axis=1)
    curr_finished_scores = tf.concat([finished_scores, curr_scores], axis=1)
    curr_finished_flags = tf.concat([finished_flags, curr_finished], axis=1)
    return utils.compute_topk_scores_and_seq(
        curr_finished_seq, curr_finished_scores, curr_finished_scores,
        curr_finished_flags, beam_size, batch_size, "grow_finished")

  def grow_alive(curr_seq, curr_scores, curr_log_probs, curr_finished,
                 states):
    """Given sequences and scores, will gather the top k=beam size sequences."""
    curr_scores += tf.to_float(curr_finished) * -INF
    return utils.compute_topk_scores_and_seq(curr_seq, curr_scores,
                                             curr_log_probs, curr_finished,
                                             beam_size, batch_size,
                                             "grow_alive", states)

  def grow_topk(i, alive_seq, alive_log_probs, states):
    """Inner beam search loop."""
    flat_ids = tf.reshape(alive_seq, [batch_size * beam_size, -1])

    # (batch_size * beam_size, decoded_length)
    if states:
      flat_states = nest.map_structure(utils.merge_beam_dim, states)
      flat_logits, flat_states = symbols_to_logits_fn(flat_ids, i,
                                                      flat_states)
      states = nest.map_structure(
          lambda t: utils.unmerge_beam_dim(t, batch_size, beam_size),
          flat_states)
    else:
      flat_logits = symbols_to_logits_fn(flat_ids)
    logits = tf.reshape(flat_logits, [batch_size, beam_size, -1])

    candidate_log_probs = utils.log_prob_from_logits(logits)
    log_probs = candidate_log_probs + tf.expand_dims(alive_log_probs, axis=2)

    length_penalty = tf.pow(((5. + tf.to_float(i + 1)) / 6.), alpha)

    curr_scores = log_probs / length_penalty
    flat_curr_scores = tf.reshape(curr_scores, [-1, beam_size * vocab_size])

    topk_scores, topk_ids = tf.nn.top_k(flat_curr_scores, k=beam_size * 2)

    topk_log_probs = topk_scores * length_penalty

    topk_beam_index = topk_ids // vocab_size
    topk_ids %= vocab_size  # Unflatten the ids
    batch_pos = utils.compute_batch_indices(batch_size, beam_size * 2)
    topk_coordinates = tf.stack([batch_pos, topk_beam_index], axis=2)

    topk_seq = tf.gather_nd(alive_seq, topk_coordinates)
    if states:
      states = nest.map_structure(
          lambda state: tf.gather_nd(state, topk_coordinates), states)

    topk_seq = tf.concat([topk_seq, tf.expand_dims(topk_ids, axis=2)], axis=2)

    topk_finished = tf.equal(topk_ids, eos_id)

    return topk_seq, topk_log_probs, topk_scores, topk_finished, states

  def inner_loop(i, alive_seq, alive_log_probs, finished_seq, finished_scores,
                 finished_flags, states):
    """Inner beam search loop."""
    topk_seq, topk_log_probs, topk_scores, topk_finished, states = grow_topk(
        i, alive_seq, alive_log_probs, states)
    alive_seq, alive_log_probs, _, states = grow_alive(
        topk_seq, topk_scores, topk_log_probs, topk_finished, states)
    finished_seq, finished_scores, finished_flags, _ = grow_finished(
        finished_seq, finished_scores, finished_flags, topk_seq, topk_scores,
        topk_finished)

    return (i + 1, alive_seq, alive_log_probs, finished_seq, finished_scores,
            finished_flags, states)

  def _is_finished(i, unused_alive_seq, alive_log_probs, unused_finished_seq,
                   finished_scores, unused_finished_in_finished,
                   unused_states):
    """Checking termination condition."""
    max_length_penalty = tf.pow(((5. + tf.to_float(decode_length)) / 6.),
                                alpha)
    lower_bound_alive_scores = alive_log_probs[:, 0] / max_length_penalty

    if not stop_early:
      lowest_score_of_finished_in_finished = tf.reduce_min(finished_scores)
    else:
      lowest_score_of_finished_in_finished = tf.reduce_max(finished_scores,
                                                           axis=1)

    bound_is_met = tf.reduce_all(
        tf.greater(lowest_score_of_finished_in_finished,
                   lower_bound_alive_scores))

    return tf.logical_and(tf.less(i, decode_length),
                          tf.logical_not(bound_is_met))

  inner_shape = tf.TensorShape([None, None, None])

  state_struc = nest.map_structure(utils.get_state_shape_invariants, states)
  (_, alive_seq, alive_log_probs, finished_seq, finished_scores,
   finished_flags, states) = tf.while_loop(
       _is_finished,
       inner_loop, [
           tf.constant(0), alive_seq, alive_log_probs, finished_seq,
           finished_scores, finished_flags, states
       ],
       shape_invariants=[
           tf.TensorShape([]), inner_shape,
           alive_log_probs.get_shape(), inner_shape,
           finished_scores.get_shape(),
           finished_flags.get_shape(), state_struc
       ],
       parallel_iterations=1,
       back_prop=False)

  alive_seq.set_shape((None, beam_size, None))
  finished_seq.set_shape((None, beam_size, None))
  finished_seq = tf.where(
      tf.reduce_any(finished_flags, 1), finished_seq, alive_seq)
  finished_scores = tf.where(
      tf.reduce_any(finished_flags, 1), finished_scores, alive_log_probs)
  return finished_seq, finished_scores, states
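# Hypothetical toy usage of beam_search (a sketch only, not from the original
# source; assumes the utils helpers referenced above are importable and a
# TF 1.x session environment). The toy symbols_to_logits_fn ignores the prefix
# and always favours token 3.
batch_size, vocab_size = 2, 5

def toy_symbols_to_logits_fn(ids):
  # shape (batch_size * beam_size, vocab_size)
  return tf.one_hot(tf.fill([tf.shape(ids)[0]], 3), vocab_size) * 10.0

init_ids = tf.zeros([batch_size], tf.int32)  # e.g. start-of-sequence id 0
seqs, scores, _ = beam_search(
    toy_symbols_to_logits_fn, init_ids, beam_size=3, decode_length=4,
    vocab_size=vocab_size, alpha=0.6, eos_id=4)
# seqs: (batch_size, beam_size, decoded_length); the highest-scoring
# hypothesis sits at beam index 0.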
def call(self, inputs, training=None, mask=None):
  dec_emb_fn = lambda ids: self.embed(ids)
  if self.is_infer:
    enc_outputs, enc_state, enc_seq_len = inputs
    batch_size = tf.shape(enc_outputs)[0]
    helper = seq2seq.GreedyEmbeddingHelper(
        embedding=dec_emb_fn,
        start_tokens=tf.fill([batch_size], self.dec_start_id),
        end_token=self.dec_end_id)
  else:
    dec_inputs, dec_seq_len, enc_outputs, enc_state, \
        enc_seq_len = inputs
    batch_size = tf.shape(enc_outputs)[0]
    dec_inputs = self.embed(dec_inputs)
    helper = seq2seq.TrainingHelper(
        inputs=dec_inputs, sequence_length=dec_seq_len)

  if self.is_infer and self.beam_size > 1:
    tiled_enc_outputs = seq2seq.tile_batch(
        enc_outputs, multiplier=self.beam_size)
    tiled_seq_len = seq2seq.tile_batch(enc_seq_len, multiplier=self.beam_size)
    attn_mech = self._build_attention(
        enc_outputs=tiled_enc_outputs, enc_seq_len=tiled_seq_len)
    dec_cell = seq2seq.AttentionWrapper(self.cell, attn_mech)
    tiled_enc_last_state = seq2seq.tile_batch(
        enc_state, multiplier=self.beam_size)
    tiled_dec_init_state = dec_cell.zero_state(
        batch_size=batch_size * self.beam_size, dtype=tf.float32)
    if self.initial_decode_state:
      tiled_dec_init_state = tiled_dec_init_state.clone(
          cell_state=tiled_enc_last_state)

    dec = seq2seq.BeamSearchDecoder(
        cell=dec_cell,
        embedding=dec_emb_fn,
        start_tokens=tf.tile([self.dec_start_id], [batch_size]),
        end_token=self.dec_end_id,
        initial_state=tiled_dec_init_state,
        beam_width=self.beam_size,
        output_layer=tf.layers.Dense(self.vocab_size),
        length_penalty_weight=self.length_penalty)
  else:
    attn_mech = self._build_attention(
        enc_outputs=enc_outputs, enc_seq_len=enc_seq_len)
    dec_cell = seq2seq.AttentionWrapper(
        cell=self.cell, attention_mechanism=attn_mech)
    dec_init_state = dec_cell.zero_state(
        batch_size=batch_size, dtype=tf.float32)
    if self.initial_decode_state:
      dec_init_state = dec_init_state.clone(cell_state=enc_state)

    dec = seq2seq.BasicDecoder(
        cell=dec_cell,
        helper=helper,
        initial_state=dec_init_state,
        output_layer=tf.layers.Dense(self.vocab_size))

  if self.is_infer:
    dec_outputs, _, _ = \
        seq2seq.dynamic_decode(decoder=dec,
                               maximum_iterations=self.max_dec_len,
                               swap_memory=self.swap_memory,
                               output_time_major=self.time_major)
    return dec_outputs.predicted_ids[:, :, 0]
  else:
    dec_outputs, _, _ = \
        seq2seq.dynamic_decode(decoder=dec,
                               maximum_iterations=tf.reduce_max(dec_seq_len),
                               swap_memory=self.swap_memory,
                               output_time_major=self.time_major)
    return dec_outputs.rnn_output
def get_pos(inputs):
  """get position id"""
  batch_size, seq_len = tf.shape(inputs)[0], tf.shape(inputs)[1]
  position_ind = tf.tile(
      tf.expand_dims(tf.range(seq_len), 0), [batch_size, 1])
  return position_ind
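# Tiny usage sketch for get_pos (illustrative values assumed, not from the
# original source): every row of the batch gets the same 0..seq_len-1 indices.
ids = tf.zeros([2, 4], tf.int32)  # (batch=2, seq_len=4)
pos = get_pos(ids)                # [[0 1 2 3]
                                  #  [0 1 2 3]]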
def call(self, inputs, training=None, mask=None):
  query, key, value = self._unpack(inputs)

  query_mask, key_mask, _ = self._unpack(mask)

  batch_size = tf.shape(query)[0]
  dimension_query = query.get_shape().as_list()[-1]
  seq_len = tf.shape(query)[-2]
  key_len = tf.shape(key)[-2]
  feature_dim = tf.shape(value)[-1]

  query = tf.matmul(
      query,
      tf.tile(tf.expand_dims(self.kernel_query, 0), [batch_size, 1, 1]))
  key = tf.matmul(
      key, tf.tile(tf.expand_dims(self.kernel_key, 0), [batch_size, 1, 1]))
  value = tf.matmul(
      value,
      tf.tile(tf.expand_dims(self.kernel_value, 0), [batch_size, 1, 1]))
  if self.use_bias:
    query += self.b_query
    key += self.b_key
    value += self.b_value

  def _reshape_multihead(origin_input):
    """
    reshape for multi head
      Input shape: (Batch size, steps, features)
      Output shape: (Batch size * head num, steps, features // head num)
    """
    return tf.concat(tf.split(origin_input, self.head_num, axis=2), axis=0)

  def _reshape_mask(mask):
    """
    repeat mask for multi head
      Input shape: (Batch size, steps)
      Output shape: (Batch size * head num, steps)
    """
    if mask is None:
      return None
    seq_len = tf.shape(mask)[1]
    mask = tf.expand_dims(mask, axis=1)
    mask = tf.tile(mask, [1, self.head_num, 1])
    return tf.reshape(mask, shape=(-1, seq_len))

  query_ = _reshape_multihead(query)
  key_ = _reshape_multihead(key)
  value_ = _reshape_multihead(value)

  key_mask = _reshape_mask(key_mask)

  # (Batch size * head num, query steps, key steps)
  similaritys = tf.matmul(query_, tf.transpose(key_, [0, 2, 1]))
  # scale
  similaritys /= tf.sqrt(tf.cast(dimension_query, tf.float32))

  if self.sequence_mask:
    ones = tf.ones((seq_len, key_len))
    similaritys -= (ones - tf.matrix_band_part(ones, -1, 0)) * 1e9
  if key_mask is not None:
    similaritys -= (1.0 - tf.cast(
        tf.expand_dims(key_mask, axis=-2), tf.float32)) * 1e9

  attention_weights = tf.keras.activations.softmax(similaritys)

  attention_outputs = tf.matmul(attention_weights, value_)
  attention_outputs = tf.reshape(
      attention_outputs,
      (-1, self.head_num, seq_len, feature_dim // self.head_num))
  attention_outputs = tf.transpose(attention_outputs, [0, 2, 1, 3])
  attention_outputs = tf.reshape(attention_outputs,
                                 (-1, seq_len, feature_dim))

  attention_outputs = tf.matmul(
      attention_outputs,
      tf.tile(tf.expand_dims(self.kernel_project, 0), [batch_size, 1, 1]))
  if self.use_bias:
    attention_outputs += self.b_project

  if self.activation is not None:
    attention_outputs = self.activation(attention_outputs)

  if query_mask is not None:
    attention_outputs *= tf.cast(
        tf.expand_dims(query_mask, axis=-1), tf.float32)

  return attention_outputs