def melspecgrams_to_stfts(melspecgrams: tf.Tensor, mel2l, ifreq=True) -> tf.Tensor:
    """Converts melspecgrams to stfts.

    Args:
      melspecgrams: Tensor of log magnitudes and instantaneous frequencies,
        shape [..., time, freq, 2*channels], mel scaling of frequencies.
      mel2l: Mel to linear matrix, i.e. the transposed linear-to-mel matrix,
        @see tf.signal.linear_to_mel_weight_matrix
      ifreq: If True, the second half of the last axis holds instantaneous
        frequencies rather than phases.

    Returns:
      stfts: Complex tensor of stft, shape [..., time, freq, channels].
    """
    melspecgrams_shape = shape_list(melspecgrams)  # [..., time, freq, channels*2]
    melspecgrams = tf.reshape(
        melspecgrams,
        melspecgrams_shape[:-1] + [melspecgrams_shape[-1] // 2, 2])
    # [..., time, freq, channels, 2]
    perm = list(range(len(melspecgrams_shape) + 1))
    perm = perm[:-4] + [perm[-2], perm[-4], perm[-3], perm[-1]]
    melspecgrams = tf.transpose(melspecgrams, perm=perm)
    # [..., channels, time, freq, 2]
    stfts = _melspecgrams_to_stfts(melspecgrams, mel2l=mel2l, ifreq=ifreq)
    # [..., channels, time, freq, 1]
    stfts = tf.squeeze(stfts, axis=-1)  # [..., channels, time, freq]
    stfts_shape = shape_list(stfts)
    perm = list(range(len(stfts_shape)))
    perm = perm[:-3] + [perm[-2], perm[-1], perm[-3]]
    return tf.transpose(stfts, perm=perm)  # [..., time, freq, channels]
def call(self, inputs, training=None, mask=None, **kwargs):
    x, n_s = inputs
    x_shape = shape_list(x)
    if self.needs_squeeze:
        n_s = tf.squeeze(n_s, axis=1)
    # One flattened kernel per symbol; look the right one up by id.
    kernels = embedding_ops.embedding_lookup(
        n_s,
        tf.reshape(self.get_weight('kernel', training=training),
                   [self.n_kernels, x_shape[-1] * self.depth]),
        symbol_dropout_rate=0.)
    ks_shape = shape_list(kernels)
    kernels = tf.reshape(
        kernels,
        [ks_shape[0]] + [1] * self.extra_dims_needed + [x_shape[-1], self.depth])
    # matrix_transpose and transpose_b cancel out; this computes x @ kernels.
    x = tf.matmul(x, tf.linalg.matrix_transpose(kernels), transpose_b=True)
    if self.use_bias:
        biases = embedding_ops.embedding_lookup(
            n_s, self.get_weight('bias', training=training),
            symbol_dropout_rate=0.)
        biases = tf.reshape(
            biases,
            [ks_shape[0]] + [1] * (self.extra_dims_needed + 1) + [self.depth])
        x += biases
    return self.activation(x)
def stfts_to_melspecgrams(stfts: tf.Tensor, l2mel, ifreq=True,
                          return_phase=True) -> tf.Tensor:
    """Converts stfts to melspecgrams.

    Args:
      stfts: Complex64/Complex128 tensor of stft,
        shape [..., time, freq, channels].
      l2mel: Linear to mel matrix, @see tf.signal.linear_to_mel_weight_matrix
      ifreq: If True, return instantaneous frequencies instead of phases.
      return_phase: If True, include the phase-derived channels in the output.

    Returns:
      melspecgrams: Tensor of log magnitudes and instantaneous frequencies,
        shape [..., time, freq, 2*channels], mel scaling of frequencies.
    """
    stfts_shape = shape_list(stfts)  # [..., time, freq, channels]
    perm = list(range(len(stfts_shape)))
    perm = perm[:-3] + [perm[-1], perm[-3], perm[-2]]
    stfts = tf.transpose(stfts, perm=perm)
    stfts = tf.expand_dims(stfts, axis=-1)  # [..., channels, time, freq, 1]
    melspecgrams = _stfts_to_melspecgrams(stfts, l2mel=l2mel, ifreq=ifreq,
                                          return_phase=return_phase)
    # [..., channels, time, freq, 2]
    melspecgrams_shape = shape_list(melspecgrams)
    perm = list(range(len(melspecgrams_shape)))
    perm = perm[:-4] + [perm[-3], perm[-2], perm[-4], perm[-1]]
    melspecgrams = tf.transpose(melspecgrams, perm=perm)
    # [..., time, freq, channels, 2]
    melspecgrams_shape = shape_list(melspecgrams)
    melspecgrams = tf.reshape(
        melspecgrams,
        melspecgrams_shape[:-2] +
        [melspecgrams_shape[-2] * melspecgrams_shape[-1]])
    # [..., time, freq, channels * 2]
    return melspecgrams
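# Round-trip sketch (bin counts and sample rate are illustrative): `l2mel`
# comes from tf.signal.linear_to_mel_weight_matrix, and its transpose is only
# a pseudo-inverse, so melspecgrams_to_stfts(stfts_to_melspecgrams(x)) is
# lossy.
def _mel_roundtrip_example(stfts):
    num_spectrogram_bins = shape_list(stfts)[-2]
    l2mel = tf.signal.linear_to_mel_weight_matrix(
        num_mel_bins=80, num_spectrogram_bins=num_spectrogram_bins,
        sample_rate=16000)
    mels = stfts_to_melspecgrams(stfts, l2mel=l2mel)
    return melspecgrams_to_stfts(mels, mel2l=tf.linalg.matrix_transpose(l2mel))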
def call(self, x, training=None, mask=None):
    assert isinstance(x, list)
    if self.mode == 'provided_mean_var':
        x, beta, gamma = x
        beta_gamma = tf.concat([beta, gamma], axis=-1)
    elif self.mode == 'mapped':
        x, beta_gamma = x
        beta = tf.matmul(beta_gamma,
                         self.get_weight('beta_kernel', training=training))
        beta = tf.nn.bias_add(beta,
                              self.get_weight('beta_bias', training=training))
        gamma = tf.matmul(beta_gamma,
                          self.get_weight('gamma_kernel', training=training))
        gamma = tf.nn.bias_add(gamma,
                               self.get_weight('gamma_bias', training=training))
        beta_gamma = tf.concat([beta, gamma], axis=-1)
    elif self.mode == 'provided_meanvar_fused':
        x, beta_gamma = x
    else:
        raise ValueError('Unknown mode: %s' % self.mode)
    beta_gamma_shape = shape_list(beta_gamma)
    x_shape = shape_list(x)
    if len(beta_gamma_shape) != len(x_shape):
        beta_gamma = tf.reshape(
            beta_gamma, [-1] + [1] * (len(x.shape) - 2) + [2, x_shape[-1]])
    else:
        beta_gamma_shape_npa = np.array(beta_gamma_shape[1:-1])
        x_shape_npa = np.array(x_shape[1:-1])
        compatible = np.all(
            np.logical_or(beta_gamma_shape_npa == 1,
                          beta_gamma_shape_npa == x_shape_npa))
        if not compatible:
            size = np.where(beta_gamma_shape_npa == 1, beta_gamma_shape_npa,
                            x_shape_npa).tolist()
            if len(beta_gamma_shape) == 4:
                beta_gamma = tf.image.resize(beta_gamma, size,
                                             method=self.method)
            elif len(beta_gamma_shape) == 3:
                beta_gamma = tf.squeeze(
                    tf.image.resize(tf.expand_dims(beta_gamma, 1),
                                    [1] + size, method=self.method,
                                    antialias=self.antialias),
                    axis=1)
            else:
                raise ValueError('Only works for 1D or 2D tensors')
        shape = [beta_gamma_shape[0]] + np.where(
            beta_gamma_shape_npa == 1, beta_gamma_shape_npa,
            x_shape_npa).tolist() + [2, x_shape[-1]]
        beta_gamma = tf.reshape(beta_gamma, shape)
    beta, gamma = tf.unstack(beta_gamma, axis=-2)
    return (x * gamma) + beta
def _compute_shape(self, inputs, seed):
    inputs_shape = shape_list(inputs)
    if self.channels > 0:
        inputs_shape[-1] = self.channels
    seed_shape = shape_list(seed)
    seed_shape = seed_shape + [1] * (len(inputs_shape) - len(seed_shape))
    random_shape = [a // b for a, b in zip(inputs_shape, seed_shape)]
    while random_shape[0] == 1:
        random_shape = random_shape[1:]
    return tf.stack(random_shape)
def call(self, inputs, training=None, **kwargs):
    y = None
    if isinstance(inputs, list):
        x, g = inputs
        d_x = shape_list(x)[-1]
        d_g = shape_list(g)[-1]
        if d_g // 2 == d_x:
            y, g = tf.split(g, num_or_size_splits=2, axis=-1)
        else:
            assert d_g == d_x
    else:
        x, g = tf.split(inputs, num_or_size_splits=2, axis=-1)
    return self.gating_function(x, g, y=y)
def _stfts_to_waves(stfts, n_fft=512, hop_length=256, discard_dc=True,
                    pad_l=128, pad_r=128, hq=True):
    """Convert from complex stfts to waves.

    Args:
      stfts: Complex64/128 tensor of stft, shape [..., channels, time, freq, 1].

    Returns:
      waves: Tensor of the waveform, shape [..., time, channels].
    """
    stfts = tf.squeeze(stfts, axis=-1)
    stfts_shape = shape_list(stfts)
    dc = 1 if discard_dc else 0
    nyq = 1 - dc
    # Restore the discarded bin so the spectrum has n_fft // 2 + 1 bins again.
    stfts = tf.pad(
        stfts,
        np.reshape(np.asarray([0, 0] * (len(stfts_shape) - 1)),
                   (-1, 2)).tolist() + [[dc, nyq]])
    if hq:
        stfts = tf.cast(stfts, tf.complex128)
    waves_resyn = tf.signal.inverse_stft(
        stfts=stfts,
        frame_length=n_fft,
        frame_step=hop_length,
        fft_length=n_fft,
        window_fn=inverse_stft_window_fn(frame_step=hop_length))
    waves_resyn = tf.linalg.matrix_transpose(waves_resyn)
    if hq:
        waves_resyn = tf.cast(waves_resyn, tf.float32)
    crops = np.reshape(np.asarray([0, 0] * (len(stfts_shape) - 3)),
                       (-1, 2)).tolist() + [[pad_l, pad_r], [0, 0]]
    return d_array_ops.crop(waves_resyn, crops)
def call(self, inputs, **kwargs):
    shape = shape_list(inputs)
    # Outer modulo keeps pad_len at 0 when the axis is already divisible.
    pad_len = (self.divisor - shape[self.axis] % self.divisor) % self.divisor
    paddings = [[0, 0]] * len(shape)
    paddings[self.axis] = ([pad_len, 0] if self.location == 'start'
                           else [0, pad_len])
    return tf.pad(inputs, paddings)
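# Usage sketch: the layer class this `call` belongs to is not shown here, so
# `PadToDivisor` and the 'end' location value are hypothetical stand-ins.
def _pad_to_divisor_example():
    layer = PadToDivisor(divisor=8, axis=1, location='end')  # hypothetical name
    x = tf.ones([2, 13, 4])
    return layer(x)  # time axis padded from 13 to 16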
def call(self, inputs, training=None):
    input_shape = shape_list(inputs)
    if len(input_shape) > 1:
        # [kernel, kernel, fmaps_in, fmaps_out] or [in, out]
        fan_in = np.prod(input_shape[:-1])
        he_std = self.gain / np.sqrt(fan_in)  # He init
        runtime_coef = he_std * self.lrmul
    else:
        runtime_coef = self.lrmul
    return self.next_layer(inputs * runtime_coef, training=training)
def extract_and_split_2d(x, kernel_size=(3, 3), strides=(1, 1),
                         dilation_rate=(1, 1), padding='same'):
    shape = shape_list(x)
    x, padding = pad_input_2d(x, padding, kernel_size=kernel_size,
                              dilation_rate=dilation_rate)
    x = tf.image.extract_patches(x,
                                 [1, kernel_size[0], kernel_size[1], 1],
                                 [1, strides[0], strides[1], 1],
                                 [1, dilation_rate[0], dilation_rate[1], 1],
                                 padding=padding)
    shape_p = shape_list(x)
    x = tf.reshape(x, shape=shape_p[:-1] +
                   [kernel_size[0] * kernel_size[1], shape[-1]])
    return x
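# Shape sketch (sizes are illustrative): extract_and_split_2d keeps the k*k
# taps of each window separate in axis -2 instead of flattening them into the
# channel axis.
def _extract_and_split_2d_example():
    x = tf.random.normal([2, 16, 16, 32])
    return extract_and_split_2d(x, kernel_size=(3, 3))  # [2, 16, 16, 9, 32]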
def masked_moments(x, axes, mask=None, keepdims=False, epsilon=1e-15):
    if mask is None:
        return tf.nn.moments(x, axes=axes, keepdims=keepdims)
    x_shape = shape_list(x)
    mask_shape = shape_list(mask)
    _mask = tf.reshape(tf.cast(mask, x.dtype),
                       mask_shape + [1] * (len(x_shape) - len(mask_shape)))
    n_mask_indices = tf.reduce_sum(_mask, axis=axes, keepdims=True)
    # Multiply by the mask so padded positions contribute nothing to the sums.
    _mean = tf.reduce_sum(x * _mask, axis=axes, keepdims=True) / tf.cast(
        tf.maximum(tf.cast(1, n_mask_indices.dtype), n_mask_indices), x.dtype)
    var = tf.reduce_sum(
        tf.math.squared_difference(x, _mean) * _mask, axis=axes,
        keepdims=True) / tf.cast(
            tf.maximum(tf.cast(1, n_mask_indices.dtype), n_mask_indices - 1),
            x.dtype)
    # These reduce_sums run over already-reduced (size-1) axes; they only drop
    # the kept dims when keepdims=False.
    return (tf.reduce_sum(_mean, axis=axes, keepdims=keepdims),
            tf.reduce_sum(var, axis=axes, keepdims=keepdims))
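# Sketch (shapes are illustrative): only the first three timesteps of each
# sequence are valid, and the mask keeps the zero padding from biasing the
# statistics.
def _masked_moments_example():
    x = tf.concat([tf.random.normal([2, 3, 8]), tf.zeros([2, 5, 8])], axis=1)
    valid = tf.concat([tf.ones([2, 3], tf.bool), tf.zeros([2, 5], tf.bool)],
                      axis=1)
    return masked_moments(x, axes=[0, 1], mask=valid)  # per-channel mean, var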
def compute_mask(self, inputs, mask=None):
    if mask is not None:
        shape = shape_list(mask)
        # Keep in sync with call(): no padding when already divisible.
        pad_len = (self.divisor - shape[self.axis] % self.divisor) % self.divisor
        paddings = [[0, 0]] * len(shape)
        paddings[self.axis] = ([pad_len, 0] if self.location == 'start'
                               else [0, pad_len])
        return tf.pad(mask, paddings, constant_values=False)
    return mask
def _stateless_random_normal(self, inputs, seed=None):
    inputs_shape = shape_list(inputs)
    if self.channels > 0:
        inputs_shape[-1] = self.channels
    if seed is None:
        return tf.random.normal(mean=self.mean, stddev=self.stddev,
                                shape=inputs_shape)
    random_shape = self._compute_shape(inputs, seed)
    out = self._recurse_generate(seed, random_shape)
    return tf.reshape(out, inputs_shape)
def _upscale2d(x, strides, method, antialias=True, gain=1):
    x_shape = shape_list(x)
    ret_h = x_shape[1] * strides[0]
    ret_w = x_shape[2] * strides[1]
    # Apply gain.
    if gain != 1:
        x *= gain
    return tf.image.resize(x, size=[ret_h, ret_w], method=method,
                           antialias=antialias)
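# Usage sketch: nearest-neighbour 2x upscaling of NHWC feature maps; `gain`
# can fold an activation rescale into the resize, as in progressive-growing
# style generators.
def _upscale2d_example(x):
    return _upscale2d(x, strides=(2, 2), method='nearest', antialias=False)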
def split_last_dimension(x, n):
    """Reshape x so that the last dimension becomes two dimensions.

    The first of these two dimensions is n.

    Args:
      x: a Tensor with shape [..., m]
      n: an integer.

    Returns:
      a Tensor with shape [..., n, m/n]
    """
    x_shape = shape_list(x)
    m = x_shape[-1]
    if isinstance(m, int) and isinstance(n, int):
        assert m % n == 0
    return tf.reshape(x, x_shape[:-1] + [n, m // n])
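# Illustrative sketch (shapes are hypothetical): split the channel axis into
# heads before attention, then check the resulting static shape.
def _split_last_dimension_example():
    x = tf.zeros([2, 10, 64])        # [batch, time, channels]
    y = split_last_dimension(x, 8)   # [2, 10, 8, 8]
    assert shape_list(y)[-2:] == [8, 8]
    return y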
def waves_to_stfts(waves: tf.Tensor, n_fft=512, hop_length=256,
                   discard_dc=True, pad_l=128, pad_r=128,
                   hq=True) -> tf.Tensor:
    """Convert from waves to complex stfts.

    Args:
      waves: Tensor of the waveform, shape [..., time, channels].

    Returns:
      stfts: Complex64 (complex128 if hq=True) tensor of stft,
        shape [..., time, freq, channels].
    """
    stfts = _waves_to_stfts(waves, n_fft=n_fft, hop_length=hop_length,
                            discard_dc=discard_dc, pad_l=pad_l, pad_r=pad_r,
                            hq=hq)
    stfts = tf.squeeze(stfts, axis=-1)  # [..., channels, time, freq]
    stfts_shape = shape_list(stfts)
    perm = list(range(len(stfts_shape)))
    perm = perm[:-3] + [perm[-2], perm[-1], perm[-3]]
    return tf.transpose(stfts, perm=perm)  # [..., time, freq, channels]
def stfts_to_waves(stfts: tf.Tensor, n_fft=512, hop_length=256,
                   discard_dc=True, pad_l=128, pad_r=128) -> tf.Tensor:
    """Convert from complex stfts to waves.

    Args:
      stfts: Complex64 tensor of stft, shape [..., time, freq, channels].

    Returns:
      waves: Tensor of the waveform, shape [..., time, channels].
    """
    stfts_shape = shape_list(stfts)
    perm = list(range(len(stfts_shape)))
    perm = perm[:-3] + [perm[-1], perm[-3], perm[-2]]
    stfts = tf.transpose(stfts, perm=perm)  # [..., channels, time, freq]
    stfts = tf.expand_dims(stfts, axis=-1)  # [..., channels, time, freq, 1]
    waves = _stfts_to_waves(stfts, n_fft=n_fft, hop_length=hop_length,
                            discard_dc=discard_dc, pad_l=pad_l, pad_r=pad_r)
    return waves  # [..., time, channels]
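# Round-trip sketch (sizes are illustrative): 16384 samples plus the default
# 128-sample pads divide evenly into 512-sample frames at a 256-sample hop,
# so the inverse recovers (approximately) the original waveform.
def _stft_roundtrip_example():
    waves = tf.random.normal([1, 16384, 1])  # [batch, time, channels]
    stfts = waves_to_stfts(waves)            # [1, frames, 256, 1]
    return stfts_to_waves(stfts)             # ~waves, [1, 16384, 1]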
def _recurse_generate(self, seed, shape):
    seed_shape = shape_list(seed)
    if len(seed_shape) == 0:
        # Widen a scalar seed into the [2]-element seed that
        # tf.random.stateless_normal expects.
        if seed.dtype == tf.int64:
            seed = tf.bitcast(seed, tf.int32)
        elif seed.dtype == tf.int32:
            seed = tf.bitcast(seed, tf.int16)
            seed = tf.cast(seed, tf.int32)
        return tf.random.stateless_normal(mean=self.mean, stddev=self.stddev,
                                          seed=seed, shape=shape)
    return tf.stack(
        [self._recurse_generate(s, shape) for s in tf.unstack(seed)])
def frequency_masking(mel_spectrograms, frequency_masking_para: int = 100,
                      frequency_mask_num: int = 1, roll_mask=None):
    """Spec augmentation Calculation Function.

    'SpecAugment' has 3 steps for audio data augmentation. The first step is
    time warping using TensorFlow's sparse_image_warp function. The second
    step is frequency masking, the last step is time masking.

    Args:
      mel_spectrograms: Tensor of log magnitudes and possibly instantaneous
        frequencies, shape [..., time, freq, ch*(1/2)], mel scaling of
        frequencies.
      frequency_masking_para(int): Augmentation parameter, "frequency mask
        parameter F". If none, default = 100 for LibriSpeech.
      frequency_mask_num(int): number of frequency masking lines, "m_F".
        If none, default = 1 for LibriSpeech.

    Returns:
      mel_spectrograms: Tensor of log magnitudes and possibly instantaneous
        frequencies, shape [..., time, freq, ch*(1/2)], mel scaling of
        frequencies.
    """
    # Step 2 : Frequency masking
    orig_dtype = mel_spectrograms.dtype
    # The float32 masks below require a float32 input; cast back at the end.
    mel_spectrograms = tf.cast(mel_spectrograms, tf.float32)
    fbank_size = shape_list(mel_spectrograms)
    _, n, n_mels, _ = fbank_size
    frequency_masking_para = min(frequency_masking_para, n_mels // 2)
    for i in range(frequency_mask_num):
        f = tf.random.uniform([], minval=0, maxval=frequency_masking_para,
                              dtype=tf.int32)
        f0 = tf.random.uniform([], minval=0, maxval=n_mels - f, dtype=tf.int32)
        # mel_spectrograms[..., f0:f0 + f, :] = 0 (mirrored along freq)
        mask = tf.concat((
            tf.ones(shape=(1, n, n_mels - f0 - f, 1)),
            tf.zeros(shape=(1, n, f, 1)),
            tf.ones(shape=(1, n, f0, 1)),
        ), 2)
        if roll_mask is not None:
            # Fill the masked band with another example from the batch.
            roll_mel_spectrograms = tf.roll(mel_spectrograms, roll_mask, axis=0)
            mel_spectrograms = (mel_spectrograms * mask +
                                roll_mel_spectrograms * (1 - mask))
        else:
            mel_spectrograms = mel_spectrograms * mask
    return tf.cast(mel_spectrograms, dtype=orig_dtype)
def sparse_warp(mel_spectrograms, time_warping_para: float = 80.):
    """Spec augmentation Calculation Function.

    'SpecAugment' has 3 steps for audio data augmentation. The first step is
    time warping using TensorFlow's sparse_image_warp function. The second
    step is frequency masking, the last step is time masking.

    Args:
      mel_spectrograms: Tensor of log magnitudes and possibly instantaneous
        frequencies, shape [..., time, freq, ch*(1/2)], mel scaling of
        frequencies.
      time_warping_para(float): Augmentation parameter, "time warp parameter
        W". If none, default = 80 for LibriSpeech.

    Returns:
      mel_spectrograms: Tensor of log magnitudes and possibly instantaneous
        frequencies, shape [..., time, freq, ch*(1/2)], mel scaling of
        frequencies.
    """
    fbank_size = shape_list(mel_spectrograms)
    _, n, n_mels, _ = fbank_size

    # Step 1 : Time warping
    # Image warping control point setting.
    # Source
    pt = tf.random.uniform(
        [], 0, n - (time_warping_para * 2),
        K.floatx()) + time_warping_para  # random point along the time axis
    src_ctr_pt_freq = tf.cast(tf.range(n_mels // 2),
                              K.floatx())  # control points on freq-axis
    src_ctr_pt_time = tf.ones_like(
        src_ctr_pt_freq) * pt  # control points on time-axis
    src_ctr_pts = tf.stack((src_ctr_pt_time, src_ctr_pt_freq), -1)
    src_ctr_pts = tf.cast(src_ctr_pts, dtype=mel_spectrograms.dtype)

    # Destination
    w = tf.random.uniform([], -time_warping_para, time_warping_para,
                          K.floatx())  # distance
    dest_ctr_pt_freq = src_ctr_pt_freq
    dest_ctr_pt_time = src_ctr_pt_time + w
    dest_ctr_pts = tf.stack((dest_ctr_pt_time, dest_ctr_pt_freq), -1)
    dest_ctr_pts = tf.cast(dest_ctr_pts, dtype=mel_spectrograms.dtype)

    # warp
    source_control_point_locations = tf.expand_dims(src_ctr_pts, 0)  # (1, v//2, 2)
    dest_control_point_locations = tf.expand_dims(dest_ctr_pts, 0)  # (1, v//2, 2)
    warped_image, _ = sparse_image_warp(mel_spectrograms,
                                        source_control_point_locations,
                                        dest_control_point_locations)
    return warped_image
def time_masking(mel_spectrograms, time_masking_para: int = 27,
                 time_mask_num: int = 1, roll_mask=None):
    """Spec augmentation Calculation Function.

    'SpecAugment' has 3 steps for audio data augmentation. The first step is
    time warping using TensorFlow's sparse_image_warp function. The second
    step is frequency masking, the last step is time masking.

    Args:
      mel_spectrograms(tf.Tensor): Tensor of log magnitudes and possibly
        instantaneous frequencies / phases, shape [..., time, freq, ch*(1/2)],
        mel scaling of frequencies.
      time_masking_para(int): Augmentation parameter, "time mask parameter T".
        If none, default = 27 for LibriSpeech.
      time_mask_num(int): number of time masking lines, "m_T". If none,
        default = 1 for LibriSpeech.

    Returns:
      mel_spectrograms: Tensor of log magnitudes and possibly instantaneous
        frequencies, shape [..., time, freq, ch*(1/2)], mel scaling of
        frequencies.
    """
    orig_dtype = mel_spectrograms.dtype
    # The float32 masks below require a float32 input; cast back at the end.
    mel_spectrograms = tf.cast(mel_spectrograms, tf.float32)
    fbank_size = shape_list(mel_spectrograms)
    _, n, n_mels, _ = fbank_size

    # Step 3 : Time masking
    for i in range(time_mask_num):
        t = tf.random.uniform([], minval=0, maxval=time_masking_para,
                              dtype=tf.int32)
        t0 = tf.random.uniform([], minval=0, maxval=n - t, dtype=tf.int32)
        # mel_spectrograms[:, t0:t0 + t] = 0 (mirrored along time)
        mask = tf.concat((
            tf.ones(shape=(1, n - t0 - t, n_mels, 1)),
            tf.zeros(shape=(1, t, n_mels, 1)),
            tf.ones(shape=(1, t0, n_mels, 1)),
        ), 1)
        if roll_mask is not None:
            roll_mel_spectrograms = tf.roll(mel_spectrograms, roll_mask, axis=0)
            mel_spectrograms = (mel_spectrograms * mask +
                                roll_mel_spectrograms * (1 - mask))
        else:
            mel_spectrograms = mel_spectrograms * mask
    return tf.cast(mel_spectrograms, dtype=orig_dtype)
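# Pipeline sketch (parameter values are illustrative): the usual SpecAugment
# order on a [batch, time, mel, 1] log-mel tensor is warp, then frequency
# masks, then time masks.
def _specaugment_example(log_mels):
    x = sparse_warp(log_mels, time_warping_para=80.)
    x = frequency_masking(x, frequency_masking_para=27, frequency_mask_num=2)
    x = time_masking(x, time_masking_para=40, time_mask_num=2)
    return x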
def time_delay_nn_1d(x, kernel, kernel_size, strides, dilation_rate,
                     padding='valid'):
    shape = shape_list(x)
    x = tf.expand_dims(x, -1)
    x, padding = pad_input_2d(x, padding, kernel_size=(kernel_size, shape[-1]),
                              dilation_rate=(dilation_rate, 1))
    x = tf.image.extract_patches(x,
                                 sizes=[1, kernel_size, shape[-1], 1],
                                 strides=[1, strides, shape[-1], 1],
                                 rates=[1, dilation_rate, 1, 1],
                                 padding=padding)
    x = tf.squeeze(x, -2)
    x = tf.matmul(x, kernel)
    return x
def embedding_lookup(x, embedding_matrix=None, name='embedding_lookup',
                     multiplier=1.0, symbol_dropout_rate=0.0):
    """Embed x of type int64 into dense vectors, reducing to max 4 dimensions."""
    with tf.name_scope(name):
        # On the backwards pass, we want to convert the gradient from
        # an indexed-slices to a regular tensor before sending it back to the
        # parameter server. This avoids excess computation on the parameter
        # server.
        if not tf.executing_eagerly():
            embedding_matrix = convert_gradient_to_tensor(embedding_matrix)
        x = dropout_no_scaling(x, 1.0 - symbol_dropout_rate)
        emb_x = gather(embedding_matrix, x)
        if multiplier != 1.0:
            emb_x *= multiplier
        static_shape = shape_list(emb_x)
        if len(static_shape) < 5:
            return emb_x
        # assert len(static_shape) == 5
        # If we had an extra channel dimension, assume it's 1, i.e. shape[3] == 1.
        return tf.squeeze(emb_x, 3)
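# Usage sketch (table and ids are made up): symbol dropout zeroes ids without
# rescaling before the gather, so dropped symbols presumably map to row 0 of
# the table.
def _embedding_lookup_example():
    table = tf.random.normal([1000, 64])
    ids = tf.constant([[3, 17, 42]], dtype=tf.int64)
    return embedding_lookup(ids, embedding_matrix=table,
                            symbol_dropout_rate=0.1)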
def call(self, x, training=None):
    orig_dtype = x.dtype
    x = tf.cast(x, tf.float32)
    x_size = shape_list(x)[1:-1]
    x_size = np.where(np.array(list(self.pool_size)) == -1, 1, x_size).tolist()
    x_size = tf.convert_to_tensor(x_size)
    # pool_size entries flagged -1 are normalized over in full; the +1 maps
    # them onto the spatial axes of the NHWC input.
    reduce_axes = (np.where(np.array(self.pool_size) == -1)[0] + 1).tolist()
    if len(reduce_axes) == 2:
        # Plain instance norm: both spatial axes are fully reduced.
        x -= tf.reduce_mean(x, axis=reduce_axes, keepdims=True)
        x *= tf.math.rsqrt(
            tf.reduce_mean(tf.square(x), axis=reduce_axes, keepdims=True) + 1e-8)
        return tf.cast(x, orig_dtype)

    pool_size_t = tf.convert_to_tensor(self.pool_size)
    pool_size_t = tf.maximum(pool_size_t, 1)
    pooled_size = x_size // pool_size_t

    def pool_reduce(x, dtype=tf.float32):
        # Smooth local statistics: downsample to the pooled grid, then resize
        # back up to the input resolution.
        if len(reduce_axes) > 0:
            x = tf.reduce_mean(x, axis=reduce_axes, keepdims=True)
        x = tf.cast(
            tf.image.resize(tf.image.resize(x, pooled_size, method=self.method,
                                            antialias=self.antialias),
                            x_size, method=self.method,
                            antialias=self.antialias), dtype)
        return x

    x -= pool_reduce(x, tf.float32)
    x *= tf.math.rsqrt(pool_reduce(tf.square(x), tf.float32) + 1e-8)
    return tf.cast(x, orig_dtype)
def _waves_to_stfts(waves: tf.Tensor, n_fft=512, hop_length=256,
                    discard_dc=True, pad_l=128, pad_r=128,
                    hq=True) -> tf.Tensor:
    """Convert from waves to complex stfts.

    Args:
      waves: Tensor of the waveform, shape [..., time, channels].

    Returns:
      stfts: Complex64 (complex128 if hq=True) tensor of stft,
        shape [..., channels, time, freq, 1].
    """
    waves_shape = shape_list(waves)
    waves = tf.linalg.matrix_transpose(waves)  # [..., channels, time]
    waves_padded = tf.pad(
        waves,
        np.reshape(np.asarray([0, 0] * (len(waves_shape) - 1)),
                   (-1, 2)).tolist() + [[pad_l, pad_r]])
    if hq:
        waves_padded = tf.cast(waves_padded, tf.float64)
    stfts = tf.signal.stft(waves_padded,
                           window_fn=tf.signal.hann_window,
                           frame_length=n_fft,
                           frame_step=hop_length,
                           fft_length=n_fft,
                           pad_end=False)
    if discard_dc:
        # Drop the first (DC) bin, keeping n_fft // 2 bins; _stfts_to_waves
        # pads the DC position back in.
        dc, stfts = tf.split(stfts, num_or_size_splits=[1, n_fft // 2], axis=-1)
    return tf.expand_dims(stfts, axis=-1)
def call(self, inputs, training=None, mask=None, **kwargs):
    training = self._get_training_value(training)
    x = inputs
    if mask is not None:
        x = tf.where(tf.expand_dims(mask, axis=-1), x, tf.zeros_like(x))
    orig_dtype = x.dtype
    x = tf.cast(x, tf.float32)
    inputs_size = array_ops.size(inputs)
    axes = list(range(len(shape_list(x))))[:-1]
    training_value = tf_utils.constant_value(training)
    if training_value == False:  # pylint: disable=singleton-comparison,g-explicit-bool-comparison
        mean, variance = self.moving_mean, self.moving_variance
    else:
        mean, variance = masked_moments(x, mask=mask, axes=axes,
                                        keepdims=False)
        mean = tf.squeeze(mean)
        variance = tf.squeeze(variance)
        moving_mean = self.moving_mean
        moving_variance = self.moving_variance
        mean = tf_utils.smart_cond(
            training, lambda: mean,
            lambda: ops.convert_to_tensor(moving_mean))
        variance = tf_utils.smart_cond(
            training, lambda: variance,
            lambda: tf.convert_to_tensor(moving_variance))

        def _do_update(var, value):
            """Compute the updates for mean and variance."""
            return self._assign_moving_average(var, value, self.momentum,
                                               inputs_size)

        def mean_update():
            true_branch = lambda: _do_update(self.moving_mean, mean)
            false_branch = lambda: self.moving_mean
            return tf_utils.smart_cond(training, true_branch, false_branch)

        def variance_update():
            """Update the moving variance."""
            true_branch = lambda: _do_update(self.moving_variance, variance)
            false_branch = lambda: self.moving_variance
            return tf_utils.smart_cond(training, true_branch, false_branch)

        self.add_update(mean_update)
        self.add_update(variance_update)
    if self.scale:
        gamma = self.get_weight('gamma', training=training)
    else:
        gamma = None
    if self.center:
        beta = self.get_weight('beta', training=training)
    else:
        beta = None
    x = tf.nn.batch_normalization(x, mean=mean, variance=variance, scale=gamma,
                                  offset=beta, variance_epsilon=self.epsilon)
    return tf.cast(x, orig_dtype)
def call(self, inputs, training=None, mask=None):
    if len(inputs) == 3:
        q, k, v = inputs
    else:
        raise ValueError()
    q_shape = shape_list(q)
    if mask is not None and self.attention_type != 'masked_local_attention_1d':
        q_mask = (1. - tf.cast(mask[0], tf.float32))[:, tf.newaxis, :, tf.newaxis]
        if self.mask_right and q_shape[1] is not None:
            # TODO: Reenable this somehow
            # @tf.function
            # def assert_mask_ok(mask_0_shape, mask_1_shape):
            #     assert mask_0_shape[1] == mask_1_shape[1] or mask_0_shape[1] == 1
            # assert_mask_ok(mask_0_shape, mask_1_shape)
            look_ahead_mask = self._create_look_ahead_mask(q_shape[1])
            q_mask = tf.maximum(q_mask, look_ahead_mask)
        kv_mask = (1. - tf.cast(mask[1], tf.float32))[:, tf.newaxis, tf.newaxis, :]
        c_mask = tf.maximum(q_mask, kv_mask)
        # c_mask = tf.maximum(c_mask, look_ahead_mask)
        bias = c_mask * large_compatible_negative(k.dtype)
    else:
        if self.attention_type != 'masked_local_attention_1d' and self.mask_right:
            look_ahead_mask = self._create_look_ahead_mask(q_shape[1])
            bias = look_ahead_mask * large_compatible_negative(k.dtype)
        else:
            bias = None
    r = None
    weights = None
    q = t2t_attention.split_heads(q, self.num_heads)
    k = t2t_attention.split_heads(k, self.num_heads_kv)
    v = t2t_attention.split_heads(v, self.num_heads_kv)
    if self.get_training_value(training):
        rate = self.dropout_rate
    else:
        rate = 0.
    if 'relative' in self.attention_type:
        key_embeddings = self.get_weight('key_embeddings', training=training)
        if self.add_relative_to_values:
            value_embeddings = self.get_weight('value_embeddings',
                                               training=training)
        else:
            value_embeddings = None
        if self.attention_type == 'unmasked_self_attention_relative':
            r, weights = t2t_attention.dot_product_unmasked_self_attention_relative_v2(
                q=q, k=k, v=v, bias=bias,
                key_leftright_embeddings=key_embeddings,
                value_leftright_embeddings=value_embeddings,
                dropout_rate=rate,
                max_relative_position=self.max_relative_position,
                heads_share_relative_embedding=self.heads_share_relative_embeddings,
                scaled=self.scaled)
        elif self.attention_type == 'masked_self_attention_relative':
            r, weights = t2t_attention.dot_product_self_attention_relative_v2(
                q=q, k=k, v=v, bias=bias,
                key_left_embedding=key_embeddings,
                value_left_embedding=value_embeddings,
                dropout_rate=rate,
                max_relative_position=self.max_relative_position,
                heads_share_relative_embedding=self.heads_share_relative_embeddings,
                scaled=self.scaled)
    else:
        if self.attention_type == 'unmasked_local_attention_1d':
            r, weights = t2t_attention.local_attention_1d(
                q=q, k=k, v=v, block_length=self.block_length,
                filter_width=self.filter_width, scaled=self.scaled)
        elif self.attention_type == 'masked_local_attention_1d':
            if mask is not None:
                attn_mask = tf.cast(mask[1], k.dtype)
            else:
                attn_mask = None
            r, weights = t2t_attention.masked_local_attention_1d(
                q=q, k=k, v=v, block_length=self.block_length,
                mask_right=self.mask_right, mask=attn_mask,
                dropout_rate=rate, scaled=self.scaled)
        elif self.attention_type == 'sparse_attention_truncated':
            r, loss, weights = t2t_attention.sparse_dot_product_attention_truncated(
                q=q, k=k, v=v, list_lsh=self.lsh_gates,
                mask_right=self.mask_right, scaled=self.scaled)
            self.add_loss(loss)
        else:
            r, weights = t2t_attention.dot_product_attention(
                q=q, k=k, v=v, bias=bias, dropout_rate=rate,
                scaled=self.scaled)
    r = t2t_attention.combine_heads(r)
    return r, weights
def get_bijector(self, x):
    event_shape_in = shape_list(x)[1:]
    chain = EventShapeAwareChain(event_shape_in,
                                 copy.copy(self.partial_bijectors))
    return chain
def call(self, inputs, training=None, mask=None, cache=None,
         decode_loop_step=None, pad_q_to_kv=False):
    x = inputs
    q, q_mask = cm(self.q_layer, x, training=training, mask=mask)
    k, k_mask = cm(self.k_layer, x, training=training, mask=mask)
    v, v_mask = cm(self.v_layer, x, training=training, mask=mask)
    if cache is not None:
        # Combine cached keys and values with new keys and values.
        if cache["k"] is not None:
            if decode_loop_step is not None:
                # In-place style update at the current decode position.
                cache_k_shape = cache["k"].shape.as_list()
                indices = tf.reshape(
                    tf.one_hot(decode_loop_step, cache_k_shape[1], dtype=k.dtype),
                    [1, cache_k_shape[1], 1])
                k = cache["k"] + k * indices
                if mask is not None:
                    indices = tf.reshape(
                        tf.one_hot(decode_loop_step, cache_k_shape[1],
                                   dtype=tf.float16),
                        [1, cache_k_shape[1]])
                    k_mask = tf.logical_or(
                        cache["k_mask"],
                        (tf.cast(k_mask, tf.float16) * indices) > 0.)
                cache_v_shape = cache["v"].shape.as_list()
                indices = tf.reshape(
                    tf.one_hot(decode_loop_step, cache_v_shape[1], dtype=v.dtype),
                    [1, cache_v_shape[1], 1])
                v = cache["v"] + v * indices
                if mask is not None:
                    indices = tf.reshape(
                        tf.one_hot(decode_loop_step, cache_v_shape[1],
                                   dtype=tf.float16),
                        [1, cache_v_shape[1]])
                    v_mask = tf.logical_or(
                        cache["v_mask"],
                        (tf.cast(v_mask, tf.float16) * indices) > 0.)
            else:
                k = tf.concat([tf.cast(cache["k"], k.dtype), k], axis=1)
                v = tf.concat([tf.cast(cache["v"], v.dtype), v], axis=1)
                if mask is not None:
                    k_mask = tf.concat(
                        [tf.cast(cache["k_mask"], k_mask.dtype), k_mask], axis=1)
                    v_mask = tf.concat(
                        [tf.cast(cache["v_mask"], v_mask.dtype), v_mask], axis=1)
        # Update cache
        cache["k"] = k
        cache["v"] = v
        if mask is not None:
            cache["k_mask"] = k_mask
            cache["v_mask"] = v_mask
    q_shape = t2t_common.shape_list(q)
    kv_shape = t2t_common.shape_list(k)
    if pad_q_to_kv:
        if q_shape[1] != kv_shape[1]:
            if decode_loop_step is not None:
                q_prepad = decode_loop_step
                q_postpad = (kv_shape[1] - q_shape[1]) - decode_loop_step
            else:
                q_prepad = kv_shape[1] - q_shape[1]
                q_postpad = 0
            q = tf.pad(q, paddings=[[0, 0], [q_prepad, q_postpad], [0, 0]])
            if mask is not None:
                q_mask = tf.pad(q_mask, paddings=[[0, 0], [q_prepad, q_postpad]])
        else:
            # This is just stupid autograph nonsense, ignore it
            if decode_loop_step is not None:
                q_prepad = decode_loop_step
            else:
                q_prepad = kv_shape[1] - q_shape[1]
    else:
        # This is just stupid autograph nonsense, ignore it
        if decode_loop_step is not None:
            q_prepad = decode_loop_step
        else:
            q_prepad = kv_shape[1] - q_shape[1]
    if mask is not None:
        mask = [q_mask, tf.logical_and(k_mask, v_mask)]
    x, weights = self.attention_layer([q, k, v], mask=mask, training=training)
    if not self.skip_out:
        x = self.out_layer(x, mask=mask, training=training)
    x_shape = t2t_common.shape_list(x)
    if pad_q_to_kv:
        if q_shape[1] != kv_shape[1]:
            if decode_loop_step is not None:
                x = tf.slice(x, [0, q_prepad, 0], [x_shape[0], 1, x_shape[2]])
            else:
                x = tf.slice(x, [0, q_prepad, 0],
                             [x_shape[0], q_shape[1], x_shape[2]])
    if self.return_attn_weights:
        return x, weights
    return x
def crop(x, crops):
    """Inverse of tf.pad: `crops` holds per-axis [begin, end] amounts to cut."""
    crops = tf.convert_to_tensor(crops)
    begins, ends = tf.unstack(crops, axis=-1)
    shape = tf.convert_to_tensor(shape_list(x))
    return tf.slice(x, begins, shape - (ends + begins))
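# Minimal sketch: undo a symmetric pad on axis 1 using the same nested-list
# format tf.pad takes.
def _crop_example():
    x = tf.pad(tf.ones([1, 4, 2]), [[0, 0], [3, 3], [0, 0]])  # [1, 10, 2]
    return crop(x, [[0, 0], [3, 3], [0, 0]])                  # [1, 4, 2]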