def call(self, inputs, training=None, mask=None):
    """Compute the classification scores for one batch.

    Args:
        inputs: mapping that provides "input_x" and, when
            `self.use_dense_task` is set, "input_dense".
        training: forwarded unchanged to `self.embed_d`.
        mask: accepted for interface compatibility; not read here.

    Returns:
        Whatever `self.final_dense` produces for the pooled features
        (annotated as [batch_size, class_num] below).
    """
    input_x = inputs["input_x"]
    if self.use_dense_task:
        dense_input = inputs["input_dense"]

    # [batch_size, max_len, embed_len]
    out = self.embed(input_x)

    if not self.use_pretrained_model:
        # Plain embeddings: max-pool over axis 1.
        out = tf.reduce_max(out, axis=1)
    else:
        logging.info("use_pretrained_model: {}, {}".format(
            self.pretrained_model_name, self.pretrained_model_mode))
        if self.pretrained_model_name == 'elmo':
            # Append the pretrained features to the embeddings along the
            # last axis, then max-pool over axis 1.
            pretrained = self.get_pre_train_graph(input_x)
            pretrained = tf.reshape(
                pretrained, [-1, self.max_len, self.pretrained_model_dim])
            out = tf.reduce_max(tf.concat([out, pretrained], axis=-1), axis=1)
        if self.pretrained_model_name == 'bert':
            # The pretrained graph output replaces the embedding output.
            out = self.get_pre_train_graph(input_x)

    out = self.embed_d(out, training=training)

    if self.use_dense_input:
        # NOTE(review): `dense_input` is only bound when
        # `self.use_dense_task` is set -- confirm both flags always agree.
        dense_out = self.dense_input_linear(dense_input)
        if self.only_dense_input:
            out = dense_out
        else:
            out = tf.keras.layers.Concatenate()([out, dense_out])

    # [batch_size, class_num]
    return self.final_dense(out)
def curvature_range(self):
    """Build the ops that track the smoothed curvature range.

    Records `self._grad_norm_squared + EPS` in a fixed-width window
    variable, takes the min and max over the part of the window filled
    so far, and feeds both through `self._moving_averager`. The averaged
    values are stored on `self._h_min` and `self._h_max`.

    Returns:
        A one-element list containing the moving-average update op.
    """
    # set up the curvature window
    # NOTE: a new non-trainable float32 variable of length
    # `_curv_win_width` is created on every call of this method.
    self._curv_win = tf.Variable(np.zeros([
        self._curv_win_width,
    ]),
                                 dtype=tf.float32,
                                 name="curv_win",
                                 trainable=False)
    # we can use log smoothing for curvature range to follow trend faster
    # self._curv_win = tf.scatter_update(
    #     self._curv_win, self._global_step % self._curv_win_width,
    #     tf.log(self._grad_norm_squared + EPS))
    # Write into the window as a ring buffer: the slot index wraps around
    # with `_global_step % _curv_win_width`.
    self._curv_win = tf.scatter_update(
        self._curv_win, self._global_step % self._curv_win_width,
        self._grad_norm_squared + EPS)
    # note here the iterations start from iteration 0
    # Only the first min(_curv_win_width, _global_step + 1) slots are used,
    # so zero-initialised slots are excluded until the window has filled.
    valid_window = tf.slice(
        self._curv_win, tf.constant([
            0,
        ]),
        tf.expand_dims(tf.minimum(tf.constant(self._curv_win_width),
                                  self._global_step + 1),
                       dim=0))

    # Per-step extremes of the window, optionally taken in log space.
    if self._h_min_log_smooth:
        self._h_min_t = tf.log(tf.reduce_min(valid_window) + EPS)
    else:
        self._h_min_t = tf.reduce_min(valid_window)
    if self._h_max_log_smooth:
        self._h_max_t = tf.log(tf.reduce_max(valid_window) + EPS)
    else:
        self._h_max_t = tf.reduce_max(valid_window)

    curv_range_ops = []
    # The averages are read only after the averaging op has been applied,
    # which in turn waits for the per-step extremes.
    with tf.control_dependencies([self._h_min_t, self._h_max_t]):
        avg_op = self._moving_averager.apply(
            [self._h_min_t, self._h_max_t])
        with tf.control_dependencies([avg_op]):
            # Values smoothed in log space are mapped back with exp.
            if self._h_min_log_smooth:
                self._h_min = tf.exp(
                    tf.identity(
                        self._moving_averager.average(self._h_min_t)))
            else:
                self._h_min = \
                    tf.identity(self._moving_averager.average(self._h_min_t))
            if self._h_max_log_smooth:
                self._h_max = tf.exp(
                    tf.identity(
                        self._moving_averager.average(self._h_max_t)))
            else:
                self._h_max = \
                    tf.identity(self._moving_averager.average(self._h_max_t))
        # NOTE(review): `_sparsity_avg` is defined outside this method;
        # presumably a running estimate of gradient sparsity -- confirm.
        if self._sparsity_debias:
            self._h_min = self._h_min * self._sparsity_avg
            self._h_max = self._h_max * self._sparsity_avg
    curv_range_ops.append(avg_op)
    return curv_range_ops
def _create_topk_unique(inputs, k):
    """Creates the top k values in sorted order with indices.

    Args:
        inputs: 2-D tensor; its first two static dimensions are read as
            (height, width). NaN entries are treated as -inf.
        k: Python int, number of values selected per row.

    Returns:
        A pair `(values, indices)`, both of shape [height, k]. `values`
        holds the row-wise maxima in descending order. `indices` is the
        float32 bit pattern of `values` masked down to its lowest
        ceil(log2(width)) bits.
        NOTE(review): this presumably relies on the caller having encoded
        the column index in those low bits -- confirm against the caller.
    """
    num_rows = inputs.shape[0]
    num_cols = inputs.shape[1]

    # A [height, width] matrix filled with -inf, used both to neutralise
    # NaNs and to knock out values that have already been selected.
    neg_inf_scalar = tf.constant(-np.inf, dtype=tf.float32)
    neg_inf_matrix = (
        tf.ones([num_rows, num_cols], dtype=tf.float32) * neg_inf_scalar)
    inputs = tf.where(tf.is_nan(inputs), neg_inf_matrix, inputs)

    # Pick the current row maximum k times. After each pick, every entry
    # greater than or equal to it is replaced with -inf for the next round.
    remaining = inputs
    top_values = tf.zeros([num_rows, k], dtype=tf.float32)
    for slot in range(k):
        row_max = tf.reduce_max(remaining, axis=1, keepdims=True)
        is_slot = tf.equal(tf.range(k), tf.fill([k], slot))
        slot_mask = tf.tile(tf.expand_dims(is_slot, 0), [num_rows, 1])
        top_values = tf.where(slot_mask, tf.tile(row_max, [1, k]),
                              top_values)
        at_or_above = tf.greater_equal(inputs,
                                       tf.tile(row_max, [1, num_cols]))
        remaining = tf.where(at_or_above, neg_inf_matrix, inputs)

    # Mask covering the lowest ceil(log2(width)) bits.
    index_bits = int(math.ceil(math.log(float(int(num_cols)), 2)))
    low_bits = (1 << index_bits) - 1
    low_bits_matrix = tf.fill([num_rows, k], tf.constant(low_bits))

    values_as_int32 = tf.bitcast(top_values, tf.int32)
    top_indices = tf.bitwise.bitwise_and(values_as_int32, low_bits_matrix)
    return top_values, top_indices
def call(self, tensors):
    """Attention layer.

    Scores every (left position, right position) pair with
    `self.middle_layer` followed by `self.attn`, normalises the scores
    over the right positions, and returns the weighted sum of `right`.

    Args:
        tensors: a pair `(left, right)` of rank-3 tensors whose axis 1 is
            the sequence axis with a statically known length.

    Returns:
        `tf.matmul(attention_weights, right)`: one attended vector per
        left position.
    """
    left, right = tensors
    left_steps = left.shape[1]
    right_steps = right.shape[1]

    # Broadcast both inputs onto a [batch, left_steps, right_steps, ...]
    # grid so that each cell holds one left/right pair.
    left_grid = tf.tile(tf.expand_dims(left, axis=2),
                        [1, 1, right_steps, 1])
    right_grid = tf.tile(tf.expand_dims(right, axis=1),
                         [1, left_steps, 1, 1])
    pair_features = tf.concat([left_grid, right_grid], axis=-1)

    # `self.attn` must emit a trailing dimension of size 1, which is
    # dropped here.
    scores = self.attn(self.middle_layer(pair_features))
    scores = tf.squeeze(scores, axis=3)

    # Softmax over the last axis; the row maximum is subtracted first for
    # numerical stability.
    shifted = scores - tf.reduce_max(scores, axis=-1, keepdims=True)
    unnormalised = tf.exp(shifted)
    normaliser = tf.reduce_sum(unnormalised, axis=-1, keepdims=True)
    attention_weights = unnormalised / normaliser

    return tf.matmul(attention_weights, right)
def _is_finished(i, unused_alive_seq, alive_log_probs, unused_finished_seq,
                 finished_scores, unused_finished_in_finished,
                 unused_states):
    """Loop condition: keep decoding while this evaluates to True.

    Reads `decode_length`, `alpha` and `stop_early` from the enclosing
    scope. Decoding continues while `i < decode_length` and the finished
    scores do not yet beat the bound on the alive scores for every entry.

    Args:
        i: current loop counter.
        alive_log_probs: log probabilities of the alive sequences; only
            column 0 is used.
        finished_scores: scores of the finished sequences.
        unused_*: loop variables that are not read here.

    Returns:
        Scalar boolean tensor, True while the loop should continue.
    """
    # Length penalty evaluated at the maximum decode length.
    length_penalty = tf.pow((5. + tf.to_float(decode_length)) / 6., alpha)
    alive_bound = alive_log_probs[:, 0] / length_penalty

    # NOTE(review): with stop_early the reference is the per-row *maximum*
    # finished score; without it, the global minimum. Confirm this
    # asymmetry is intended.
    if stop_early:
        finished_reference = tf.reduce_max(finished_scores, axis=1)
    else:
        finished_reference = tf.reduce_min(finished_scores)

    bound_is_met = tf.reduce_all(
        tf.greater(finished_reference, alive_bound))
    within_length = tf.less(i, decode_length)
    return tf.logical_and(within_length, tf.logical_not(bound_is_met))
def compute_mel_filterbank_features(waveforms,
                                    sample_rate=16000,
                                    dither=1.0 / np.iinfo(np.int16).max,
                                    preemphasis=0.97,
                                    frame_length=25,
                                    frame_step=10,
                                    fft_length=None,
                                    window_fn=functools.partial(
                                        tf.signal.hann_window, periodic=True),
                                    lower_edge_hertz=80.0,
                                    upper_edge_hertz=7600.0,
                                    num_mel_bins=80,
                                    log_noise_floor=1e-3,
                                    apply_mask=True):
    """Implement mel-filterbank extraction using tf ops.

    Args:
      waveforms: float32 tensor with shape [batch_size, max_len]
      sample_rate: sampling rate of the waveform
      dither: stddev of Gaussian noise added to waveform to prevent
        quantization artefacts; skipped when not positive
      preemphasis: waveform high-pass filtering constant; skipped when not
        positive
      frame_length: frame length in ms
      frame_step: frame step in ms
      fft_length: number of fft bins; defaults to the smallest power of two
        that is >= the frame length in samples
      window_fn: windowing function
      lower_edge_hertz: lowest frequency of the filterbank
      upper_edge_hertz: highest frequency of the filterbank
      num_mel_bins: filterbank size
      log_noise_floor: clip small values to prevent numeric overflow in log
      apply_mask: When working on a batch of samples, set padding frames to
        zero

    Returns:
      filterbanks: a float32 tensor with shape [batch_size, len, num_bins, 1]
    """
    # Find the wave length: the largest index for which the value is != 0.
    # Samples that are exactly 0.0 are quite common inside a waveform, so
    # simply counting non-zero samples along the last axis would not work.
    sample_index = tf.expand_dims(tf.range(tf.shape(waveforms)[1]), 0)
    is_nonzero = tf.to_int32(tf.not_equal(waveforms, 0.0))
    wav_lens = tf.reduce_max(sample_index * is_nonzero, axis=-1) + 1

    if dither > 0:
        waveforms += tf.random_normal(tf.shape(waveforms), stddev=dither)
    if preemphasis > 0:
        # First-order difference filter; it shortens every signal by one.
        waveforms = waveforms[:, 1:] - preemphasis * waveforms[:, :-1]
        wav_lens -= 1

    # Convert the framing parameters from milliseconds to samples.
    frame_length = int(frame_length * sample_rate / 1e3)
    frame_step = int(frame_step * sample_rate / 1e3)
    if fft_length is None:
        fft_length = int(2**(np.ceil(np.log2(frame_length))))

    # `stfts` is the short-time Fourier transform of each signal in
    # `waveforms`. Its shape is [batch_size, ?, fft_unique_bins]
    # where fft_unique_bins = fft_length // 2 + 1.
    stfts = tf.signal.stft(waveforms,
                           frame_length=frame_length,
                           frame_step=frame_step,
                           fft_length=fft_length,
                           window_fn=window_fn,
                           pad_end=True)

    # Number of frames per signal (ceiling division) and the matching
    # frame mask.
    # NOTE(review): `less_equal` also keeps the frame whose index equals
    # `stft_lens` -- confirm that this extra frame is intended.
    stft_lens = (wav_lens + (frame_step - 1)) // frame_step
    frame_index = tf.expand_dims(tf.range(tf.shape(stfts)[1]), 0)
    masks = tf.to_float(
        tf.less_equal(frame_index, tf.expand_dims(stft_lens, 1)))

    # A magnitude spectrogram is the absolute value of the complex STFT:
    # a float32 tensor of shape [batch_size, ?, fft_length // 2 + 1].
    magnitude_spectrograms = tf.abs(stfts)

    # Warp the linear-scale, magnitude spectrograms into the mel-scale.
    num_spectrogram_bins = magnitude_spectrograms.shape[-1].value
    linear_to_mel_weight_matrix = tf.signal.linear_to_mel_weight_matrix(
        num_mel_bins, num_spectrogram_bins, sample_rate, lower_edge_hertz,
        upper_edge_hertz)
    mel_spectrograms = tf.tensordot(magnitude_spectrograms,
                                    linear_to_mel_weight_matrix, 1)
    # Note: Shape inference for tensordot does not currently handle this
    # case, so the static shape is set by hand.
    leading_dims = magnitude_spectrograms.shape[:-1]
    mel_spectrograms.set_shape(
        leading_dims.concatenate(linear_to_mel_weight_matrix.shape[-1:]))

    log_mel_sgram = tf.log(tf.maximum(log_noise_floor, mel_spectrograms))

    if apply_mask:
        log_mel_sgram *= tf.expand_dims(tf.to_float(masks), -1)

    return tf.expand_dims(log_mel_sgram, -1, name="mel_sgrams")
def call(self, inputs, training=None, mask=None):
    """Run the attention decoder for training or inference.

    Three modes are supported:

    * training (`self.is_infer` false): teacher forcing through a
      `TrainingHelper` and a `BasicDecoder`.
    * greedy inference (`self.is_infer` true, `self.beam_size <= 1`):
      `GreedyEmbeddingHelper` and a `BasicDecoder`.
    * beam-search inference (`self.is_infer` true, `self.beam_size > 1`):
      `BeamSearchDecoder` over encoder tensors tiled `beam_size` times.

    Args:
        inputs: in inference a 3-tuple
            `(enc_outputs, enc_state, enc_seq_len)`; in training a 5-tuple
            `(dec_inputs, dec_seq_len, enc_outputs, enc_state,
            enc_seq_len)`.
        training: accepted for interface compatibility; not read here.
        mask: accepted for interface compatibility; not read here.

    Returns:
        Training: `rnn_output` of the decoder.
        Greedy inference: `sample_id` of the decoder (the chosen ids).
        Beam-search inference: `predicted_ids[:, :, 0]`, i.e. beam 0.
    """
    dec_emb_fn = lambda ids: self.embed(ids)
    use_beam_search = self.is_infer and self.beam_size > 1

    if self.is_infer:
        enc_outputs, enc_state, enc_seq_len = inputs
        batch_size = tf.shape(enc_outputs)[0]
        # Only the BasicDecoder path consumes this helper; beam search
        # receives the embedding function and start/end tokens directly.
        helper = seq2seq.GreedyEmbeddingHelper(
            embedding=dec_emb_fn,
            start_tokens=tf.fill([batch_size], self.dec_start_id),
            end_token=self.dec_end_id)
    else:
        dec_inputs, dec_seq_len, enc_outputs, enc_state, \
            enc_seq_len = inputs
        batch_size = tf.shape(enc_outputs)[0]
        dec_inputs = self.embed(dec_inputs)
        helper = seq2seq.TrainingHelper(
            inputs=dec_inputs, sequence_length=dec_seq_len)

    if use_beam_search:
        # Every encoder-side tensor is repeated beam_size times so that
        # each beam attends over its own copy.
        tiled_enc_outputs = seq2seq.tile_batch(
            enc_outputs, multiplier=self.beam_size)
        tiled_seq_len = seq2seq.tile_batch(
            enc_seq_len, multiplier=self.beam_size)
        attn_mech = self._build_attention(
            enc_outputs=tiled_enc_outputs, enc_seq_len=tiled_seq_len)
        dec_cell = seq2seq.AttentionWrapper(self.cell, attn_mech)
        tiled_enc_last_state = seq2seq.tile_batch(
            enc_state, multiplier=self.beam_size)
        tiled_dec_init_state = dec_cell.zero_state(
            batch_size=batch_size * self.beam_size, dtype=tf.float32)
        if self.initial_decode_state:
            # Start decoding from the encoder's final state.
            tiled_dec_init_state = tiled_dec_init_state.clone(
                cell_state=tiled_enc_last_state)
        dec = seq2seq.BeamSearchDecoder(
            cell=dec_cell,
            embedding=dec_emb_fn,
            start_tokens=tf.tile([self.dec_start_id], [batch_size]),
            end_token=self.dec_end_id,
            initial_state=tiled_dec_init_state,
            beam_width=self.beam_size,
            output_layer=tf.layers.Dense(self.vocab_size),
            length_penalty_weight=self.length_penalty)
    else:
        attn_mech = self._build_attention(
            enc_outputs=enc_outputs, enc_seq_len=enc_seq_len)
        dec_cell = seq2seq.AttentionWrapper(
            cell=self.cell, attention_mechanism=attn_mech)
        dec_init_state = dec_cell.zero_state(
            batch_size=batch_size, dtype=tf.float32)
        if self.initial_decode_state:
            # Start decoding from the encoder's final state.
            dec_init_state = dec_init_state.clone(cell_state=enc_state)
        dec = seq2seq.BasicDecoder(
            cell=dec_cell,
            helper=helper,
            initial_state=dec_init_state,
            output_layer=tf.layers.Dense(self.vocab_size))

    # Inference is bounded by the configured maximum length; training by
    # the longest target sequence in the batch.
    if self.is_infer:
        max_iterations = self.max_dec_len
    else:
        max_iterations = tf.reduce_max(dec_seq_len)

    dec_outputs, _, _ = seq2seq.dynamic_decode(
        decoder=dec,
        maximum_iterations=max_iterations,
        swap_memory=self.swap_memory,
        output_time_major=self.time_major)

    if use_beam_search:
        # Keep only the first beam of the beam-search result.
        return dec_outputs.predicted_ids[:, :, 0]
    if self.is_infer:
        # A BasicDecoder output carries `rnn_output` and `sample_id` but no
        # `predicted_ids`, so greedy inference returns the sampled ids.
        # (Assumes `seq2seq` is the tf.contrib.seq2seq API -- TODO confirm.)
        return dec_outputs.sample_id
    return dec_outputs.rnn_output