def curvature_range(self):
    """Build the ops that track the smoothed curvature range.

    The most recent squared gradient norms are kept in a fixed-width,
    non-trainable window variable. The min and max over the filled part of
    that window (optionally in log space) are passed to
    ``self._moving_averager`` and the averaged results are stored in
    ``self._h_min`` and ``self._h_max``.

    Returns:
        A one-element list containing the moving-average update op.
    """
    width = self._curv_win_width

    # Window of recent squared gradient norms, initialised to zeros.
    self._curv_win = tf.Variable(
        np.zeros([width]),
        dtype=tf.float32,
        name="curv_win",
        trainable=False)

    # Write the current value into slot (step mod width), so the oldest
    # entry is overwritten once the window is full. Storing
    # tf.log(self._grad_norm_squared + EPS) here instead would apply log
    # smoothing so the curvature range follows the trend faster.
    slot = self._global_step % width
    newest = self._grad_norm_squared + EPS
    self._curv_win = tf.scatter_update(self._curv_win, slot, newest)

    # Iterations start at 0, so only the first min(width, step + 1)
    # entries have been written so far.
    window_begin = tf.constant([0])
    filled = tf.minimum(tf.constant(width), self._global_step + 1)
    valid_window = tf.slice(
        self._curv_win, window_begin, tf.expand_dims(filled, dim=0))

    window_min = tf.reduce_min(valid_window)
    self._h_min_t = (
        tf.log(window_min + EPS) if self._h_min_log_smooth else window_min)

    window_max = tf.reduce_max(valid_window)
    self._h_max_t = (
        tf.log(window_max + EPS) if self._h_max_log_smooth else window_max)

    def _read_average(raw, log_smooth):
        # Fetch the averaged value; undo the log when log smoothing is on.
        averaged = tf.identity(self._moving_averager.average(raw))
        return tf.exp(averaged) if log_smooth else averaged

    instantaneous = [self._h_min_t, self._h_max_t]
    with tf.control_dependencies(instantaneous):
        avg_op = self._moving_averager.apply(instantaneous)
        # Read the averages only after they have been updated.
        with tf.control_dependencies([avg_op]):
            self._h_min = _read_average(
                self._h_min_t, self._h_min_log_smooth)
            self._h_max = _read_average(
                self._h_max_t, self._h_max_log_smooth)
            if self._sparsity_debias:
                self._h_min = self._h_min * self._sparsity_avg
                self._h_max = self._h_max * self._sparsity_avg

    return [avg_op]
def focal_loss(logits, labels, alpha, gamma=2, name='focal_loss'):
    """Compute the summed focal loss for multi-class classification.

    Args:
        logits: A float32 tensor of shape [batch_size, num_class].
        labels: An int32 tensor of shape [batch_size] (class ids, converted
            to one-hot here) or [batch_size, num_class].
        alpha: A 1D float32 tensor holding the focal-loss alpha
            hyper-parameter; it is multiplied element-wise (with
            broadcasting) into the per-class loss.
        gamma: A scalar, the focal-loss gamma (focusing) hyper-parameter.
        name: Tag under which the loss is written as a scalar summary.

    Returns:
        A scalar tensor: the focal loss summed over every example and class.
    """
    if len(labels.shape) == 1:
        labels = tf.one_hot(labels, logits.shape[-1])
    labels = tf.to_float(labels)

    # Softmax over the last (class) axis, which is the default axis.
    y_pred = tf.nn.softmax(logits)
    # Take the log from log_softmax rather than log(softmax): when a
    # probability underflows to 0, log(0) = -inf and 0 * -inf would turn the
    # whole summed loss into NaN.
    log_y_pred = tf.nn.log_softmax(logits)

    per_class_loss = -labels * log_y_pred
    per_class_loss *= alpha * ((1 - y_pred) ** gamma)
    loss = tf.reduce_sum(per_class_loss)

    if tf.executing_eagerly():
        tf.contrib.summary.scalar(name, loss)
    else:
        tf.summary.scalar(name, loss)
    return loss
def compute_mel_filterbank_features(waveforms,
                                    sample_rate=16000,
                                    dither=1.0 / np.iinfo(np.int16).max,
                                    preemphasis=0.97,
                                    frame_length=25,
                                    frame_step=10,
                                    fft_length=None,
                                    window_fn=functools.partial(
                                        tf.signal.hann_window, periodic=True),
                                    lower_edge_hertz=80.0,
                                    upper_edge_hertz=7600.0,
                                    num_mel_bins=80,
                                    log_noise_floor=1e-3,
                                    apply_mask=True):
    """Extract log mel-filterbank features from a batch of waveforms.

    Args:
        waveforms: float32 tensor with shape [batch_size, max_len].
        sample_rate: sampling rate of the waveform.
        dither: stddev of the Gaussian noise added to the waveform to
            prevent quantization artefacts; skipped when not positive.
        preemphasis: waveform high-pass filtering constant; skipped when
            not positive.
        frame_length: frame length in ms.
        frame_step: frame step in ms.
        fft_length: number of FFT bins; when None, the smallest power of
            two that is >= the frame length in samples is used.
        window_fn: windowing function passed to the STFT.
        lower_edge_hertz: lowest frequency of the filterbank.
        upper_edge_hertz: highest frequency of the filterbank.
        num_mel_bins: filterbank size.
        log_noise_floor: lower clip applied to the mel energies before the
            log is taken.
        apply_mask: when working on a batch of samples, set padding frames
            to zero.

    Returns:
        A float32 tensor with shape [batch_size, len, num_mel_bins, 1].
    """
    # Length of each waveform: 1 + the largest index whose sample is != 0.
    # Samples that are exactly 0.0 are quite common inside real audio, so
    # simply counting the non-zero samples would not give the right length.
    sample_index = tf.expand_dims(tf.range(tf.shape(waveforms)[1]), 0)
    is_nonzero = tf.to_int32(tf.not_equal(waveforms, 0.0))
    wav_lens = tf.reduce_max(sample_index * is_nonzero, axis=-1) + 1

    if dither > 0:
        noise = tf.random_normal(tf.shape(waveforms), stddev=dither)
        waveforms = waveforms + noise

    if preemphasis > 0:
        current = waveforms[:, 1:]
        previous = waveforms[:, :-1]
        waveforms = current - preemphasis * previous
        # The difference above drops one sample from every waveform.
        wav_lens = wav_lens - 1

    # Convert the frame geometry from milliseconds to samples.
    frame_length = int(frame_length * sample_rate / 1e3)
    frame_step = int(frame_step * sample_rate / 1e3)
    if fft_length is None:
        fft_length = int(2 ** (np.ceil(np.log2(frame_length))))

    # `stfts` is a complex64 tensor holding the short-time Fourier transform
    # of each waveform. Its shape is [batch_size, ?, fft_unique_bins] where
    # fft_unique_bins = fft_length // 2 + 1.
    stfts = tf.signal.stft(
        waveforms,
        frame_length=frame_length,
        frame_step=frame_step,
        fft_length=fft_length,
        window_fn=window_fn,
        pad_end=True)

    # Frames per example, rounding the waveform length up to a whole step.
    stft_lens = (wav_lens + (frame_step - 1)) // frame_step
    frame_index = tf.expand_dims(tf.range(tf.shape(stfts)[1]), 0)
    # NOTE(review): `less_equal` also keeps the frame at index `stft_lens`,
    # i.e. one more frame than `stft_lens` - confirm this slack is intended.
    masks = tf.to_float(
        tf.less_equal(frame_index, tf.expand_dims(stft_lens, 1)))

    # Magnitude of the complex STFT: float32, same shape as `stfts`
    # (the last dimension is 257 with the default arguments).
    magnitudes = tf.abs(stfts)

    # Warp the linear-scale magnitude spectrograms onto the mel scale.
    num_spectrogram_bins = magnitudes.shape[-1].value
    mel_weights = tf.signal.linear_to_mel_weight_matrix(
        num_mel_bins, num_spectrogram_bins, sample_rate,
        lower_edge_hertz, upper_edge_hertz)
    mel_spectrograms = tf.tensordot(magnitudes, mel_weights, 1)
    # Shape inference for tensordot does not handle this case, so the
    # static shape is set by hand.
    mel_spectrograms.set_shape(
        magnitudes.shape[:-1].concatenate(mel_weights.shape[-1:]))

    log_mel_sgram = tf.log(tf.maximum(log_noise_floor, mel_spectrograms))
    if apply_mask:
        log_mel_sgram *= tf.expand_dims(tf.to_float(masks), -1)

    return tf.expand_dims(log_mel_sgram, -1, name="mel_sgrams")