Example #1
def melspecgrams_to_stfts(melspecgrams: tf.Tensor, mel2l, ifreq=True) -> tf.Tensor:
    """Converts melspecgrams to stfts.
        Args:
          melspecgrams: Tensor of log magnitudes and instantaneous frequencies,
            shape [..., time, freq, 2*channels], mel scaling of frequencies.
          mel2l: Mel to linear matrix, ie transposed linear to mel matrix
            @see
            tf.signal.linear_to_mel_weight_matrix
        Returns:
          specgrams: Tensor of log magnitudes and instantaneous frequencies,
            shape [..., time, freq, channels].
        """
    melspecgrams_shape = shape_list(melspecgrams)  # [..., time, freq, channels*2]
    melspecgrams = tf.reshape(melspecgrams,
                              melspecgrams_shape[:-1] + [melspecgrams_shape[-1] // 2,
                                                         2])  # [..., time, freq, channels, 2]
    perm = list(range(len(melspecgrams_shape) + 1))
    perm = perm[:-4] + [perm[-2], perm[-4], perm[-3], perm[-1]]
    melspecgrams = tf.transpose(melspecgrams, perm=perm)  # [..., channels, time, freq, 2]
    stfts = _melspecgrams_to_stfts(melspecgrams, mel2l=mel2l, ifreq=ifreq)  # [..., channels, time, freq, 1]
    stfts = tf.squeeze(stfts, axis=-1)  # [..., channels, time, freq]
    stfts_shape = shape_list(stfts)
    perm = list(range(len(stfts_shape)))
    perm = perm[:-3] + [perm[-2], perm[-1], perm[-3]]
    stfts = tf.transpose(stfts, perm=perm)  # [..., time, freq, channels]
    return stfts
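These conversion helpers all lean on two idioms: a shape_list helper that mixes static and dynamic dimensions, and axis permutations built by editing list(range(rank)). A minimal sketch of both (the shape_list body below is the common tensor2tensor-style implementation, which is an assumption about this codebase):

import tensorflow as tf

def shape_list(x):
    """Static dimensions where known, dynamic tensors otherwise."""
    x = tf.convert_to_tensor(x)
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [d if d is not None else dynamic[i] for i, d in enumerate(static)]

x = tf.zeros([4, 128, 80, 2])                 # [batch, time, freq, channels]
perm = list(range(len(shape_list(x))))
perm = perm[:-3] + [perm[-1], perm[-3], perm[-2]]
y = tf.transpose(x, perm=perm)                # [batch, channels, time, freq]
print(y.shape)                                # (4, 2, 128, 80)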
Example #2
 def call(self, inputs, training=None, mask=None, **kwargs):
     x, n_s = inputs
     x_shape = shape_list(x)
     if self.needs_squeeze:
         n_s = tf.squeeze(n_s, axis=1)
     kernels = embedding_ops.embedding_lookup(
         n_s,
         tf.reshape(self.get_weight('kernel', training=training),
                    [self.n_kernels, x_shape[-1] * self.depth]),
         symbol_dropout_rate=0.)
     ks_shape = shape_list(kernels)
     kernels = tf.reshape(kernels,
                          [ks_shape[0]] + [1] * (self.extra_dims_needed) +
                          [x_shape[-1], self.depth])
     # The matrix_transpose and transpose_b=True cancel out; this is
     # equivalent to tf.matmul(x, kernels), batched over the looked-up kernels.
     x = tf.matmul(x, tf.linalg.matrix_transpose(kernels), transpose_b=True)
     if self.use_bias:
         biases = embedding_ops.embedding_lookup(n_s,
                                                 self.get_weight(
                                                     'bias',
                                                     training=training),
                                                 symbol_dropout_rate=0.)
         biases = tf.reshape(biases, [ks_shape[0]] + [1] *
                             (self.extra_dims_needed + 1) + [self.depth])
         x += biases
     return self.activation(x)
Example #3
def stfts_to_melspecgrams(stfts: tf.Tensor, l2mel, ifreq=True, return_phase=True) -> tf.Tensor:
    """Converts stfts to specgrams.
    Args:
      stfts: Complex64/Complex128 tensor of stft, shape [..., time, freq, channels].
    Returns:
      melspecgrams: Tensor of log magnitudes and instantaneous frequencies,
        shape [..., time, freq, 2*channels], mel scaling of frequencies.
    """
    # inp: [..., time, freq, channels]
    stfts_shape = shape_list(stfts)
    perm = list(range(len(stfts_shape)))
    perm = perm[:-3] + [perm[-1], perm[-3], perm[-2]]
    stfts = tf.transpose(stfts, perm=perm)
    stfts = tf.expand_dims(stfts, axis=-1)  # [..., channels, time, freq, 1]
    melspecgrams = _stfts_to_melspecgrams(stfts, l2mel=l2mel, ifreq=ifreq,
                                          return_phase=return_phase)  # [..., channels, time, freq, 2]

    melspecgrams_shape = shape_list(melspecgrams)
    perm = list(range(len(melspecgrams_shape)))
    perm = perm[:-4] + [perm[-3], perm[-2], perm[-4], perm[-1]]
    melspecgrams = tf.transpose(melspecgrams, perm=perm)  # [..., time, freq, channels, 2]
    melspecgrams_shape = shape_list(melspecgrams)
    melspecgrams = tf.reshape(melspecgrams, melspecgrams_shape[:-2] + [
        melspecgrams_shape[-2] * melspecgrams_shape[-1]])  # [..., time, freq, channels * 2]
    return melspecgrams
Example #4
    def call(self, x, training=None, mask=None):
        assert isinstance(x, list)
        if self.mode == 'provided_mean_var':
            x, beta, gamma = x
            beta_gamma = tf.concat([beta, gamma], axis=-1)
        elif self.mode == 'mapped':
            x, beta_gamma = x
            beta = tf.matmul(beta_gamma,
                             self.get_weight('beta_kernel', training=training))
            beta = tf.nn.bias_add(
                beta, self.get_weight('beta_bias', training=training))
            gamma = tf.matmul(
                beta_gamma, self.get_weight('gamma_kernel', training=training))
            gamma = tf.nn.bias_add(
                gamma, self.get_weight('gamma_bias', training=training))
            beta_gamma = tf.concat([beta, gamma], axis=-1)
        elif self.mode != 'provided_meanvar_fused':
            raise ValueError(f'Unknown mode: {self.mode}')
        else:
            x, beta_gamma = x
        beta_gamma_shape = shape_list(beta_gamma)
        x_shape = shape_list(x)
        if len(beta_gamma_shape) != len(x_shape):
            beta_gamma = tf.reshape(beta_gamma, [
                -1,
            ] + ([1] * (len(x.shape) - 2)) + [2, x_shape[-1]])
        else:
            beta_gamma_shape_npa = np.array(beta_gamma_shape[1:-1])
            x_shape_npa = np.array(x_shape[1:-1])
            compatible = np.all(
                np.logical_or(beta_gamma_shape_npa == 1,
                              beta_gamma_shape_npa == x_shape_npa))
            if not compatible:
                size = np.where(beta_gamma_shape_npa == 1,
                                beta_gamma_shape_npa, x_shape_npa).tolist()
                if len(beta_gamma_shape) == 4:
                    beta_gamma = tf.image.resize(beta_gamma,
                                                 size,
                                                 method=self.method)
                elif len(beta_gamma_shape) == 3:
                    beta_gamma = tf.squeeze(tf.image.resize(
                        tf.expand_dims(beta_gamma, 1), [1] + size,
                        method=self.method,
                        antialias=self.antialias),
                                            axis=1)
                else:
                    raise ValueError('Only works for 1D or 2D tensors')

            shape = [beta_gamma_shape[0]] + np.where(
                beta_gamma_shape_npa == 1, beta_gamma_shape_npa,
                x_shape_npa).tolist() + [2, x_shape[-1]]
            beta_gamma = tf.reshape(beta_gamma, shape)
        beta, gamma = tf.unstack(beta_gamma, axis=-2)

        return (x * gamma) + beta
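The core of the layer above is feature-wise affine conditioning (FiLM-style): once beta and gamma are resolved and broadcast-compatible, the output is just x * gamma + beta. A standalone sketch of that final step, with illustrative shapes (an assumption):

import tensorflow as tf

x = tf.random.normal([2, 16, 16, 32])     # feature map
beta = tf.zeros([2, 1, 1, 32])            # per-channel shift, broadcast over H and W
gamma = tf.ones([2, 1, 1, 32]) * 1.5      # per-channel scale
out = (x * gamma) + beta
print(out.shape)                          # (2, 16, 16, 32)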
Example #5
 def _compute_shape(self, inputs, seed):
     inputs_shape = shape_list(inputs)
     if self.channels > 0:
         inputs_shape[-1] = self.channels
     seed_shape = shape_list(seed)
     seed_shape = seed_shape + ([1] * (len(inputs_shape) - len(seed_shape)))
     random_shape = [a // b for a, b in zip(inputs_shape, seed_shape)]
     # Drop leading axes whose size is fully covered by the seed (ratio 1),
     # so _recurse_generate recurses once per seed entry.
     while random_shape[0] == 1:
         random_shape = random_shape[1:]
     return tf.stack(random_shape)
Example #6
 def call(self, inputs, training=None, **kwargs):
     y = None
     if isinstance(inputs, list):
         x, g = inputs
         d_x = shape_list(x)[-1]
         d_g = shape_list(g)[-1]
         if d_g // 2 == d_x:
             y, g = tf.split(g, num_or_size_splits=2, axis=-1)
         else:
             assert d_g == d_x
     else:
         x, g = tf.split(inputs, num_or_size_splits=2, axis=-1)
     return self.gating_function(x, g, y=y)
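When a single tensor is passed, it is split in half along the channel axis into a signal x and a gate g. A minimal sketch of that path, assuming a sigmoid gating function (GLU-style; the actual gating_function is configured elsewhere on the layer):

import tensorflow as tf

inputs = tf.random.normal([2, 10, 16])
x, g = tf.split(inputs, num_or_size_splits=2, axis=-1)
out = x * tf.sigmoid(g)                   # one common choice of gating function
print(out.shape)                          # (2, 10, 8)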
Example #7
def _stfts_to_waves(stfts, n_fft=512, hop_length=256, discard_dc=True, pad_l=128, pad_r=128, hq=True):
    """Convert from complex stfts to waves.
    Args:
      stfts: Complex64/128 tensor of stft, shape [..., channels, time, freq, 1].
    Returns:
      waves: Tensor of the waveform, shape [..., time, channels].
    """
    stfts = tf.squeeze(stfts, axis=-1)
    stfts_shape = shape_list(stfts)
    dc = 1 if discard_dc else 0
    nyq = 1 - dc
    # Pad back the frequency bin dropped by _waves_to_stfts: a zero DC bin at
    # the front when discard_dc, otherwise a zero Nyquist bin at the end.
    stfts = tf.pad(stfts, np.reshape(np.asarray([0, 0] * (len(stfts_shape) - 1)), (-1, 2)).tolist() + [[dc, nyq]])
    if hq:
        stfts = tf.cast(stfts, tf.complex128)
    waves_resyn = tf.signal.inverse_stft(
        stfts=stfts,
        frame_length=n_fft,
        frame_step=hop_length,
        fft_length=n_fft,
        window_fn=tf.signal.inverse_stft_window_fn(frame_step=hop_length))
    waves_resyn = tf.linalg.matrix_transpose(waves_resyn)
    if hq:
        waves_resyn = tf.cast(waves_resyn, tf.float32)
    crops = np.reshape(np.asarray([0, 0] * (len(stfts_shape) - 3)), (-1, 2)).tolist() + [[pad_l, pad_r], [0, 0]]
    return d_array_ops.crop(waves_resyn, crops)
Example #8
    def call(self, inputs, **kwargs):
        shape = shape_list(inputs)

        # When the length is already divisible, this pads a full extra block
        # (pad_len == self.divisor); the compute_mask in example #12 matches.
        pad_len = self.divisor - (shape[self.axis] % self.divisor)
        paddings = [[0, 0]] * len(shape)
        paddings[self.axis] = [pad_len, 0
                               ] if self.location == 'start' else [0, pad_len]
        return tf.pad(inputs, paddings)
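A quick sketch of the padding behavior, assuming divisor=4, axis=1, and location='end': a length-10 sequence is padded to 12 (and, per the note above, a length-12 one would be padded to 16):

import tensorflow as tf

inputs = tf.ones([2, 10, 8])
divisor, axis = 4, 1
pad_len = divisor - (inputs.shape[axis] % divisor)
paddings = [[0, 0]] * 3
paddings[axis] = [0, pad_len]             # location == 'end'
print(tf.pad(inputs, paddings).shape)     # (2, 12, 8)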
Example #9
 def call(self, inputs, training=None):
     input_shape = shape_list(inputs)
     if len(input_shape) > 1:
         fan_in = np.prod(input_shape[:-1])  # [kernel, kernel, fmaps_in, fmaps_out] or [in, out]
         he_std = self.gain / np.sqrt(fan_in)  # He init
         runtime_coef = he_std * self.lrmul
     else:
         runtime_coef = self.lrmul
     return self.next_layer(inputs * runtime_coef, training=training)
Example #10
def extract_and_split_2d(x,
                         kernel_size=(3, 3),
                         strides=(1, 1),
                         dilation_rate=(1, 1),
                         padding='same'):
    shape = shape_list(x)
    x, padding = pad_input_2d(x,
                              padding,
                              kernel_size=kernel_size,
                              dilation_rate=dilation_rate)
    x = tf.image.extract_patches(x, [1, kernel_size[0], kernel_size[1], 1],
                                 [1, strides[0], strides[1], 1],
                                 [1, dilation_rate[0], dilation_rate[1], 1],
                                 padding=padding)
    shape_p = shape_list(x)
    x = tf.reshape(x,
                   shape=shape_p[:-1] +
                   [kernel_size[0] * kernel_size[1], shape[-1]])
    return x
Example #11
def masked_moments(x, axes, mask=None, keepdims=False, epsilon=1e-15):
    if mask is None:
        return tf.nn.moments(x, axes=axes, keepdims=keepdims)
    else:
        x_shape = shape_list(x)
        mask_shape = shape_list(mask)
        _mask = tf.reshape(tf.cast(mask, x.dtype),
                           mask_shape + [1] * (len(x_shape) - len(mask_shape)))
    # NOTE: x is assumed to be zeroed outside the mask (the caller in
    # example #26 applies tf.where before calling), so a plain reduce_sum
    # over x already gives the masked sum.
    n_mask_indices = tf.reduce_sum(_mask, axis=axes, keepdims=True)
    _mean = tf.reduce_sum(x, axis=axes, keepdims=True) / tf.cast(
        tf.maximum(tf.cast(1, n_mask_indices.dtype), n_mask_indices), x.dtype)
    # Mask the squared differences: masked-out (zeroed) positions would
    # otherwise each contribute mean^2. Divides by n - 1, unlike
    # tf.nn.moments, which divides by n.
    var = tf.reduce_sum(tf.math.squared_difference(x, _mean) * _mask,
                        axis=axes,
                        keepdims=True) / tf.cast(
                            tf.maximum(tf.cast(1, n_mask_indices.dtype),
                                       n_mask_indices - 1), x.dtype)
    # _mean and var are size 1 along `axes`; these reduce_sums only drop
    # (or keep) the singleton dimensions according to keepdims.
    return tf.reduce_sum(_mean, axis=axes,
                         keepdims=keepdims), tf.reduce_sum(var,
                                                           axis=axes,
                                                           keepdims=keepdims)
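A usage sketch for masked_moments: moments over the first three of five timesteps, zeroing x outside the mask beforehand, as the function assumes:

import tensorflow as tf

x = tf.random.normal([2, 5, 8])
mask = tf.constant([[True, True, True, False, False]] * 2)
x = tf.where(mask[..., tf.newaxis], x, tf.zeros_like(x))
mean, var = masked_moments(x, axes=[0, 1], mask=mask)
print(mean.shape, var.shape)              # (8,) (8,)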
Example #12
    def compute_mask(self, inputs, mask=None):
        if mask is not None:
            shape = shape_list(mask)

            pad_len = self.divisor - (shape[self.axis] % self.divisor)
            paddings = [[0, 0]] * len(shape)
            paddings[self.axis] = [
                pad_len, 0
            ] if self.location == 'start' else [0, pad_len]
            return tf.pad(mask, paddings, constant_values=False)
        return mask
Example #13
 def _stateless_random_normal(self, inputs, seed=None):
     inputs_shape = shape_list(inputs)
     if self.channels > 0:
         inputs_shape[-1] = self.channels
     if seed is None:
         return tf.random.normal(mean=self.mean,
                                 stddev=self.stddev,
                                 shape=inputs_shape)
     else:
         random_shape = self._compute_shape(inputs, seed)
         out = self._recurse_generate(seed, random_shape)
         return tf.reshape(out, inputs_shape)
Example #14
def _upscale2d(x, strides, method, antialias=True, gain=1):
    x_shape = shape_list(x)
    ret_h = x_shape[1] * strides[0]
    ret_w = x_shape[2] * strides[1]

    # Apply gain.
    if gain != 1:
        x *= gain

    return tf.image.resize(x,
                           size=[ret_h, ret_w],
                           method=method,
                           antialias=antialias)
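A usage sketch for _upscale2d (it relies on the shape_list helper used throughout): nearest-neighbor 2x upscaling of an NHWC batch:

import tensorflow as tf

x = tf.random.normal([1, 8, 8, 3])
y = _upscale2d(x, strides=(2, 2), method='nearest', antialias=False)
print(y.shape)                            # (1, 16, 16, 3)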
Example #15
def split_last_dimension(x, n):
    """Reshape x so that the last dimension becomes two dimensions.
    The first of these two dimensions is n.
    Args:
      x: a Tensor with shape [..., m]
      n: an integer.
    Returns:
      a Tensor with shape [..., n, m/n]
    """
    x_shape = shape_list(x)
    m = x_shape[-1]
    if isinstance(m, int) and isinstance(n, int):
        assert m % n == 0
    return tf.reshape(x, x_shape[:-1] + [n, m // n])
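A usage sketch: splitting a 64-dim feature axis into 8 heads of 8 dims each, which is what multi-head attention's split_heads builds on:

import tensorflow as tf

x = tf.zeros([2, 10, 64])
print(split_last_dimension(x, 8).shape)   # (2, 10, 8, 8)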
Example #16
def waves_to_stfts(waves: tf.Tensor, n_fft=512, hop_length=256, discard_dc=True, pad_l=128, pad_r=128,
                   hq=True) -> tf.Tensor:
    """Convert from waves to complex stfts.
       Args:
         waves: Tensor of the waveform, shape [..., time, channels].
       Returns:
         stfts: Complex64 tensor of stft, shape [..., time, freq, channels].
       """
    stfts = _waves_to_stfts(waves, n_fft=n_fft, hop_length=hop_length, discard_dc=discard_dc, pad_l=pad_l, pad_r=pad_r,
                            hq=hq)
    stfts = tf.squeeze(stfts, axis=-1)  # [..., channels, time, freq]
    stfts_shape = shape_list(stfts)
    perm = list(range(len(stfts_shape)))
    perm = perm[:-3] + [perm[-2], perm[-1], perm[-3]]
    return tf.transpose(stfts, perm=perm)
Example #17
def stfts_to_waves(stfts: tf.Tensor, n_fft=512, hop_length=256, discard_dc=True, pad_l=128, pad_r=128) -> tf.Tensor:
    """Convert from complex stfts to waves.
    Args:
      stfts: Complex64 tensor of stft, shape [..., time, freq, channels].
    Returns:
      waves: Tensor of the waveform, shape [..., time, channels].
    """
    stfts_shape = shape_list(stfts)
    perm = list(range(len(stfts_shape)))
    perm = perm[:-3] + [perm[-1], perm[-3], perm[-2]]
    stfts = tf.transpose(stfts, perm=perm)  # [..., channels, time, freq]
    stfts = tf.expand_dims(stfts, axis=-1)
    waves = _stfts_to_waves(stfts, n_fft=n_fft, hop_length=hop_length, discard_dc=discard_dc, pad_l=pad_l,
                            pad_r=pad_r)  # [..., time, channels]
    return waves
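A round-trip sketch with the defaults above (n_fft=512, hop_length=256, pad_l=pad_r=128): 16384 input samples are padded to 16640, giving 64 frames of 256 frequency bins, and the inverse crops back to the original length:

import tensorflow as tf

waves = tf.random.normal([2, 16384, 1])   # [batch, time, channels]
stfts = waves_to_stfts(waves)             # (2, 64, 256, 1)
recon = stfts_to_waves(stfts)             # (2, 16384, 1)
print(stfts.shape, recon.shape)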
Example #18
 def _recurse_generate(self, seed, shape):
     seed_shape = shape_list(seed)
     if len(seed_shape) == 0:
         # tf.random.stateless_normal needs a seed of shape [2]: bitcast a
         # scalar int64 into two int32s, or an int32 into two int16s that
         # are then widened back to int32.
         if seed.dtype == tf.int64:
             seed = tf.bitcast(seed, tf.int32)
         elif seed.dtype == tf.int32:
             seed = tf.bitcast(seed, tf.int16)
             seed = tf.cast(seed, tf.int32)
         return tf.random.stateless_normal(mean=self.mean,
                                           stddev=self.stddev,
                                           seed=seed,
                                           shape=shape)
     else:
         return tf.stack(
             [self._recurse_generate(s, shape) for s in tf.unstack(seed)])
Example #19
def frequency_masking(mel_spectrograms,
                      frequency_masking_para: int = 100,
                      frequency_mask_num: int = 1,
                      roll_mask=None):
    """Spec augmentation Calculation Function.
    'SpecAugment' have 3 steps for audio data augmentation.
    first step is time warping using Tensorflow's image_sparse_warp function.
    Second step is frequency masking, last step is time masking.
    Args:
      mel_spectrograms: Tensor of log magnitudes and possibly instantaneous frequencies,
            shape [..., time, freq, ch*(1/2)], mel scaling of frequencies.
      frequency_masking_para(int): Augmentation parameter, "frequency mask parameter F"
        If none, default = 100 for LibriSpeech.
      frequency_mask_num(int): number of frequency masking lines, "m_F".
        If none, default = 1 for LibriSpeech.
    Returns:
      mel_spectrograms: Tensor of log magnitudes and possibly instantaneous frequencies,
            shape [..., time, freq, ch*(1/2)], mel scaling of frequencies.
    """
    # Step 2 : Frequency masking
    orig_dtype = mel_spectrograms.dtype
    fbank_size = shape_list(mel_spectrograms)
    _, n, n_mels, _ = fbank_size
    frequency_masking_para = min(frequency_masking_para, n_mels // 2)

    for i in range(frequency_mask_num):
        f = tf.random.uniform([],
                              minval=0,
                              maxval=frequency_masking_para,
                              dtype=tf.int32)
        f0 = tf.random.uniform([], minval=0, maxval=n_mels - f, dtype=tf.int32)

        # Zero out mel bins [n_mels - f0 - f, n_mels - f0), i.e. the mask is
        # measured from the top bin.
        mask = tf.concat((
            tf.ones(shape=(1, n, n_mels - f0 - f, 1)),
            tf.zeros(shape=(1, n, f, 1)),
            tf.ones(shape=(1, n, f0, 1)),
        ), 2)
        if roll_mask is not None:
            roll_mel_spectrograms = tf.roll(mel_spectrograms,
                                            roll_mask,
                                            axis=0)
            mel_spectrograms = (mel_spectrograms *
                                mask) + (roll_mel_spectrograms * (1 - mask))
        else:
            mel_spectrograms = mel_spectrograms * mask
    return tf.cast(mel_spectrograms, dtype=orig_dtype)
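A usage sketch on a rank-4 batch (the function unpacks exactly four dimensions):

import tensorflow as tf

mel = tf.random.normal([1, 100, 80, 1])   # [batch, time, n_mels, ch]
out = frequency_masking(mel, frequency_masking_para=27, frequency_mask_num=2)
print(out.shape)                          # (1, 100, 80, 1)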
Example #20
def sparse_warp(mel_spectrograms, time_warping_para: float = 80.):
    """Spec augmentation Calculation Function.
    'SpecAugment' have 3 steps for audio data augmentation.
    first step is time warping using Tensorflow's image_sparse_warp function.
    Second step is frequency masking, last step is time masking.
    Args:
      mel_spectrograms: Tensor of log magnitudes and possibly instantaneous frequencies,
            shape [..., time, freq, ch*(1/2)], mel scaling of frequencies.
      time_warping_para(float): Augmentation parameter, "time warp parameter W".
        If none, default = 80 for LibriSpeech.
    Returns:
      mel_spectrograms: Tensor of log magnitudes and possibly instantaneous frequencies,
            shape [..., time, freq, ch*(1/2)], mel scaling of frequencies.
    """

    fbank_size = shape_list(mel_spectrograms)
    _, n, n_mels, _ = fbank_size
    # Step 1 : Time warping
    # Image warping control point setting.
    # Source
    pt = tf.random.uniform(
        [], 0, n - (time_warping_para * 2),
        K.floatx()) + time_warping_para  # random point along the time axis
    src_ctr_pt_freq = tf.cast(tf.range(n_mels // 2),
                              K.floatx())  # control points on freq-axis
    src_ctr_pt_time = tf.ones_like(
        src_ctr_pt_freq) * pt  # control points on time-axis
    src_ctr_pts = tf.stack((src_ctr_pt_time, src_ctr_pt_freq), -1)
    src_ctr_pts = tf.cast(src_ctr_pts, dtype=mel_spectrograms.dtype)

    # Destination
    w = tf.random.uniform([], -time_warping_para, time_warping_para,
                          K.floatx())  # distance
    dest_ctr_pt_freq = src_ctr_pt_freq
    dest_ctr_pt_time = src_ctr_pt_time + w
    dest_ctr_pts = tf.stack((dest_ctr_pt_time, dest_ctr_pt_freq), -1)
    dest_ctr_pts = tf.cast(dest_ctr_pts, dtype=mel_spectrograms.dtype)

    # warp
    source_control_point_locations = tf.expand_dims(src_ctr_pts,
                                                    0)  # (1, v//2, 2)
    dest_control_point_locations = tf.expand_dims(dest_ctr_pts,
                                                  0)  # (1, v//2, 2)
    warped_image, _ = sparse_image_warp(mel_spectrograms,
                                        source_control_point_locations,
                                        dest_control_point_locations)
    return warped_image
Example #21
def time_masking(mel_spectrograms,
                 time_masking_para: int = 27,
                 time_mask_num: int = 1,
                 roll_mask=None):
    """Spec augmentation Calculation Function.
    'SpecAugment' have 3 steps for audio data augmentation.
    first step is time warping using Tensorflow's image_sparse_warp function.
    Second step is frequency masking, last step is time masking.
    Args:
      mel_spectrograms(tf.Tensor): Tensor of log magnitudes and possibly instantaneous frequencies / phases,
            shape [..., time, freq, ch*(1/2)], mel scaling of frequencies.
      time_masking_para(int): Augmentation parameter, "time mask parameter T"
        If none, default = 27 for LibriSpeech.
      time_mask_num(int): number of time masking lines, "m_T".
        If none, default = 1 for LibriSpeech.
    Returns:
      mel_spectrogram: Tensor of log magnitudes and possibly instantaneous frequencies,
            shape [..., time, freq, ch*(1/2)], mel scaling of frequencies.
    """
    orig_dtype = mel_spectrograms.dtype
    fbank_size = shape_list(mel_spectrograms)
    _, n, n_mels, _ = fbank_size
    # Step 3 : Time masking
    for i in range(time_mask_num):
        t = tf.random.uniform([],
                              minval=0,
                              maxval=time_masking_para,
                              dtype=tf.int32)
        t0 = tf.random.uniform([], minval=0, maxval=n - t, dtype=tf.int32)

        # Zero out time steps [n - t0 - t, n - t0), i.e. the mask is measured
        # from the end of the sequence.
        mask = tf.concat((
            tf.ones(shape=(1, n - t0 - t, n_mels, 1)),
            tf.zeros(shape=(1, t, n_mels, 1)),
            tf.ones(shape=(1, t0, n_mels, 1)),
        ), 1)
        if roll_mask is not None:
            roll_mel_spectrograms = tf.roll(mel_spectrograms,
                                            roll_mask,
                                            axis=0)
            mel_spectrograms = (mel_spectrograms *
                                mask) + (roll_mel_spectrograms * (1 - mask))
        else:
            mel_spectrograms = mel_spectrograms * mask

    return tf.cast(mel_spectrograms, dtype=orig_dtype)
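The three SpecAugment steps chained in the order the docstrings describe (sparse_image_warp is assumed to come from tensorflow_addons.image in this codebase):

import tensorflow as tf

mel = tf.random.normal([1, 200, 80, 1])
mel = sparse_warp(mel, time_warping_para=40.)
mel = frequency_masking(mel, frequency_masking_para=27, frequency_mask_num=2)
mel = time_masking(mel, time_masking_para=40, time_mask_num=2)
print(mel.shape)                          # (1, 200, 80, 1)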
Example #22
def time_delay_nn_1d(x,
                     kernel,
                     kernel_size,
                     strides,
                     dilation_rate,
                     padding='valid'):
    shape = shape_list(x)
    x = tf.expand_dims(x, -1)
    x, padding = pad_input_2d(x,
                              padding,
                              kernel_size=(kernel_size, shape[-1]),
                              dilation_rate=(dilation_rate, 1))
    x = tf.image.extract_patches(x,
                                 sizes=[1, kernel_size, shape[-1], 1],
                                 strides=[1, strides, shape[-1], 1],
                                 rates=[1, dilation_rate, 1, 1],
                                 padding=padding)
    x = tf.squeeze(x, -2)
    x = tf.matmul(x, kernel)
    return x
Example #23
def embedding_lookup(x,
                     embedding_matrix=None,
                     name='embedding_lookup',
                     multiplier=1.0,
                     symbol_dropout_rate=0.0):
    """Embed x of type int64 into dense vectors, reducing to max 4 dimensions."""
    with tf.name_scope(name):
        # On the backwards pass, we want to convert the gradient from
        # an indexed-slices to a regular tensor before sending it back to the
        # parameter server. This avoids excess computation on the parameter server.
        if not tf.executing_eagerly():
            embedding_matrix = convert_gradient_to_tensor(embedding_matrix)
        x = dropout_no_scaling(x, 1.0 - symbol_dropout_rate)
        emb_x = gather(embedding_matrix, x)
        if multiplier != 1.0:
            emb_x *= multiplier
        static_shape = shape_list(emb_x)
        if len(static_shape) < 5:
            return emb_x
        # If there is an extra channel dimension, assume it is 1
        # (i.e. shape[3] == 1) and squeeze it to cap the result at 4 dims.
        return tf.squeeze(emb_x, 3)
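A usage sketch, assuming this module's gather and dropout_no_scaling helpers reduce to tf.gather and the identity at a symbol dropout rate of 0:

import tensorflow as tf

emb = tf.random.normal([1000, 64])        # [vocab, depth]
ids = tf.constant([[3, 5, 7]])
print(embedding_lookup(ids, embedding_matrix=emb).shape)  # (1, 3, 64)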
Example #24
    def call(self, x, training=None):
        orig_dtype = x.dtype
        x = tf.cast(x, tf.float32)
        x_size = shape_list(x)[1:-1]
        x_size = np.where(np.array(list(self.pool_size)) == -1, 1,
                          x_size).tolist()
        x_size = tf.convert_to_tensor(x_size)
        reduce_axes = np.where(np.array(self.pool_size) == -1)[0].tolist()
        if len(reduce_axes) == 2:
            x -= tf.reduce_mean(x, axis=reduce_axes, keepdims=True)
            x *= tf.math.rsqrt(
                tf.reduce_mean(tf.square(x), axis=reduce_axes, keepdims=True) +
                1e-8)
            x = tf.cast(x, orig_dtype)
            return x
        pool_size_t = tf.convert_to_tensor(self.pool_size)
        pool_size_t = tf.maximum(pool_size_t, 1)
        pooled_size = x_size // pool_size_t

        def pool_reduce(x, dtype=tf.float32):
            if len(reduce_axes) > 0:
                x = tf.reduce_mean(x, axis=reduce_axes, keepdims=True)
            x = tf.cast(
                tf.image.resize(tf.image.resize(x,
                                                pooled_size,
                                                method=self.method,
                                                antialias=self.antialias),
                                x_size,
                                method=self.method,
                                antialias=self.antialias), dtype)
            return x

        x -= pool_reduce(x, tf.float32)
        x *= tf.math.rsqrt(pool_reduce(tf.square(x), tf.float32) + 1e-8)
        x = tf.cast(x, orig_dtype)
        return x
Example #25
def _waves_to_stfts(waves: tf.Tensor, n_fft=512, hop_length=256, discard_dc=True, pad_l=128, pad_r=128,
                    hq=True) -> tf.Tensor:
    """Convert from waves to complex stfts.
    Args:
      waves: Tensor of the waveform, shape [..., time, channels].
    Returns:
      stfts: Complex64 tensor of stft, shape [..., channels, time, freq, 1].
    """
    waves_shape = shape_list(waves)
    waves = tf.linalg.matrix_transpose(waves)  # [..., channels, time]
    waves_padded = tf.pad(waves,
                          np.reshape(np.asarray([0, 0] * (len(waves_shape) - 1)), (-1, 2)).tolist() + [[pad_l, pad_r]])
    if hq:
        waves_padded = tf.cast(waves_padded, tf.float64)
    stfts = tf.signal.stft(
        waves_padded,
        window_fn=tf.signal.hann_window,
        frame_length=n_fft,
        frame_step=hop_length,
        fft_length=n_fft,
        pad_end=False)
    if discard_dc:
        # Discard the DC bin; _stfts_to_waves pads a zero DC bin back at the
        # front ([[dc, nyq]] with dc=1).
        _, stfts = tf.split(stfts, num_or_size_splits=[1, n_fft // 2], axis=-1)
    else:
        # Keep DC and discard the Nyquist bin instead, mirroring the nyq=1
        # padding at the end in _stfts_to_waves.
        stfts, _ = tf.split(stfts, num_or_size_splits=[n_fft // 2, 1], axis=-1)
    return tf.expand_dims(stfts, axis=-1)
Example #26
    def call(self, inputs, training=None, mask=None, **kwargs):
        training = self._get_training_value(training)
        x = inputs
        if mask is not None:
            x = tf.where(tf.expand_dims(mask, axis=-1), x, tf.zeros_like(x))
        orig_dtype = x.dtype
        x = tf.cast(x, tf.float32)
        inputs_size = array_ops.size(inputs)
        axes = list(range(len(shape_list(x))))[:-1]
        training_value = tf_utils.constant_value(training)
        if training_value == False:  # pylint: disable=singleton-comparison,g-explicit-bool-comparison
            mean, variance = self.moving_mean, self.moving_variance
        else:
            mean, variance = masked_moments(x,
                                            mask=mask,
                                            axes=axes,
                                            keepdims=False)
            mean = tf.squeeze(mean)
            variance = tf.squeeze(variance)
            moving_mean = self.moving_mean
            moving_variance = self.moving_variance

            mean = tf_utils.smart_cond(
                training, lambda: mean,
                lambda: ops.convert_to_tensor(moving_mean))
            variance = tf_utils.smart_cond(
                training, lambda: variance,
                lambda: tf.convert_to_tensor(moving_variance))

            def _do_update(var, value):
                """Compute the updates for mean and variance."""
                return self._assign_moving_average(var, value, self.momentum,
                                                   inputs_size)

            def mean_update():
                true_branch = lambda: _do_update(self.moving_mean, mean)
                false_branch = lambda: self.moving_mean
                return tf_utils.smart_cond(training, true_branch, false_branch)

            def variance_update():
                """Update the moving variance."""
                true_branch = lambda: _do_update(self.moving_variance,
                                                 variance)
                false_branch = lambda: self.moving_variance
                return tf_utils.smart_cond(training, true_branch, false_branch)

            self.add_update(mean_update)
            self.add_update(variance_update)

        if self.scale:
            gamma = self.get_weight('gamma', training=training)
        else:
            gamma = None
        if self.center:
            beta = self.get_weight('beta', training=training)
        else:
            beta = None

        x = tf.nn.batch_normalization(x,
                                      mean=mean,
                                      variance=variance,
                                      scale=gamma,
                                      offset=beta,
                                      variance_epsilon=self.epsilon)
        x = tf.cast(x, orig_dtype)
        return x
Example #27
    def call(self, inputs, training=None, mask=None):
        if len(inputs) == 3:
            q, k, v = inputs
        else:
            raise ValueError()

        q_shape = shape_list(q)

        if mask is not None and self.attention_type != 'masked_local_attention_1d':
            q_mask = (1. - tf.cast(mask[0], tf.float32))[:, tf.newaxis, :,
                                                         tf.newaxis]
            if self.mask_right and q_shape[1] is not None:
                # TODO: Reenable this somehow
                """
                @tf.function
                def assert_mask_ok(mask_0_shape, mask_1_shape):
                    assert mask_0_shape[1] == mask_1_shape[1] or mask_0_shape[1] == 1

                assert_mask_ok(mask_0_shape, mask_1_shape)
                """
                look_ahead_mask = self._create_look_ahead_mask(q_shape[1])
                q_mask = tf.maximum(q_mask, look_ahead_mask)
            kv_mask = (1. - tf.cast(mask[1], tf.float32))[:, tf.newaxis,
                                                          tf.newaxis, :]
            c_mask = tf.maximum(q_mask, kv_mask)

            # c_mask = tf.maximum(c_mask, look_ahead_mask)
            bias = c_mask * large_compatible_negative(k.dtype)
        else:
            if self.attention_type != 'masked_local_attention_1d' and self.mask_right:
                look_ahead_mask = self._create_look_ahead_mask(q_shape[1])
                bias = look_ahead_mask * large_compatible_negative(k.dtype)
            else:
                bias = None

        r = None
        weights = None

        q = t2t_attention.split_heads(q, self.num_heads)
        k = t2t_attention.split_heads(k, self.num_heads_kv)
        v = t2t_attention.split_heads(v, self.num_heads_kv)
        if self.get_training_value(training):
            rate = self.dropout_rate
        else:
            rate = 0.

        if 'relative' in self.attention_type:
            key_embeddings = self.get_weight('key_embeddings',
                                             training=training)
            if self.add_relative_to_values:
                value_embeddings = self.get_weight('value_embeddings',
                                                   training=training)
            else:
                value_embeddings = None
            if self.attention_type == 'unmasked_self_attention_relative':
                r, weights = t2t_attention.dot_product_unmasked_self_attention_relative_v2(
                    q=q,
                    k=k,
                    v=v,
                    bias=bias,
                    key_leftright_embeddings=key_embeddings,
                    value_leftright_embeddings=value_embeddings,
                    dropout_rate=rate,
                    max_relative_position=self.max_relative_position,
                    heads_share_relative_embedding=self.
                    heads_share_relative_embeddings,
                    scaled=self.scaled)
            elif self.attention_type == 'masked_self_attention_relative':
                r, weights = t2t_attention.dot_product_self_attention_relative_v2(
                    q=q,
                    k=k,
                    v=v,
                    bias=bias,
                    key_left_embedding=key_embeddings,
                    value_left_embedding=value_embeddings,
                    dropout_rate=rate,
                    max_relative_position=self.max_relative_position,
                    heads_share_relative_embedding=self.
                    heads_share_relative_embeddings,
                    scaled=self.scaled)
        else:
            if self.attention_type == 'unmasked_local_attention_1d':
                r, weights = t2t_attention.local_attention_1d(
                    q=q,
                    k=k,
                    v=v,
                    block_length=self.block_length,
                    filter_width=self.filter_width,
                    scaled=self.scaled)
            elif self.attention_type == 'masked_local_attention_1d':
                if mask is not None:
                    attn_mask = tf.cast(mask[1], k.dtype)
                else:
                    attn_mask = None
                r, weights = t2t_attention.masked_local_attention_1d(
                    q=q,
                    k=k,
                    v=v,
                    block_length=self.block_length,
                    mask_right=self.mask_right,
                    mask=attn_mask,
                    dropout_rate=rate,
                    scaled=self.scaled)
            elif self.attention_type == 'sparse_attention_truncated':
                r, loss, weights = t2t_attention.sparse_dot_product_attention_truncated(
                    q=q,
                    k=k,
                    v=v,
                    list_lsh=self.lsh_gates,
                    mask_right=self.mask_right,
                    scaled=self.scaled)
                self.add_loss(loss)
            else:
                r, weights = t2t_attention.dot_product_attention(
                    q=q,
                    k=k,
                    v=v,
                    bias=bias,
                    dropout_rate=rate,
                    scaled=self.scaled)

        r = t2t_attention.combine_heads(r)
        return r, weights
Example #28
 def get_bijector(self, x):
     event_shape_in = shape_list(x)[1:]
     chain = EventShapeAwareChain(event_shape_in,
                                  copy.copy(self.partial_bijectors))
     return chain
Example #29
    def call(self,
             inputs,
             training=None,
             mask=None,
             cache=None,
             decode_loop_step=None,
             pad_q_to_kv=False):
        x = inputs
        q, q_mask = cm(self.q_layer, x, training=training, mask=mask)
        k, k_mask = cm(self.k_layer, x, training=training, mask=mask)
        v, v_mask = cm(self.v_layer, x, training=training, mask=mask)
        if cache is not None:
            # Combine cached keys and values with new keys and values.
            if cache["k"] is not None:
                # Update cache
                if decode_loop_step is not None:

                    cache_k_shape = cache["k"].shape.as_list()
                    indices = tf.reshape(
                        tf.one_hot(decode_loop_step,
                                   cache_k_shape[1],
                                   dtype=k.dtype), [1, cache_k_shape[1], 1])
                    k = cache["k"] + k * indices
                    if mask is not None:
                        indices = tf.reshape(
                            tf.one_hot(decode_loop_step,
                                       cache_k_shape[1],
                                       dtype=tf.float16),
                            [1, cache_k_shape[1]])
                        k_mask = tf.logical_or(
                            cache["k_mask"],
                            (tf.cast(k_mask, tf.float16) * indices) > 0.)

                    cache_v_shape = cache["v"].shape.as_list()
                    indices = tf.reshape(
                        tf.one_hot(decode_loop_step,
                                   cache_v_shape[1],
                                   dtype=v.dtype), [1, cache_v_shape[1], 1])
                    v = cache["v"] + v * indices
                    if mask is not None:
                        indices = tf.reshape(
                            tf.one_hot(decode_loop_step,
                                       cache_v_shape[1],
                                       dtype=tf.float16),
                            [1, cache_v_shape[1]])
                        v_mask = tf.logical_or(
                            cache["v_mask"],
                            (tf.cast(v_mask, tf.float16) * indices) > 0.)
                else:
                    k = tf.concat([tf.cast(cache["k"], k.dtype), k], axis=1)
                    v = tf.concat([tf.cast(cache["v"], v.dtype), v], axis=1)
                    if mask is not None:
                        k_mask = tf.concat(
                            [tf.cast(cache["k_mask"], k_mask.dtype), k_mask],
                            axis=1)
                        v_mask = tf.concat(
                            [tf.cast(cache["v_mask"], v_mask.dtype), v_mask],
                            axis=1)

            # Update cache
            cache["k"] = k
            cache["v"] = v
            if mask is not None:
                cache["k_mask"] = k_mask
                cache["v_mask"] = v_mask

        q_shape = t2t_common.shape_list(q)
        kv_shape = t2t_common.shape_list(k)

        if pad_q_to_kv:
            if q_shape[1] != kv_shape[1]:
                if decode_loop_step is not None:
                    q_prepad = decode_loop_step
                    q_postpad = (kv_shape[1] - q_shape[1]) - decode_loop_step

                else:
                    q_prepad = (kv_shape[1] - q_shape[1])
                    q_postpad = 0
                q = tf.pad(q, paddings=[[0, 0], [q_prepad, q_postpad], [0, 0]])
                if mask is not None:
                    q_mask = tf.pad(q_mask,
                                    paddings=[[0, 0], [q_prepad, q_postpad]])
            else:
                # AutoGraph needs q_prepad defined on every branch; this
                # value is unused.
                if decode_loop_step is not None:
                    q_prepad = decode_loop_step
                else:
                    q_prepad = (kv_shape[1] - q_shape[1])
        else:
            # AutoGraph needs q_prepad defined on every branch; this value
            # is unused.
            if decode_loop_step is not None:
                q_prepad = decode_loop_step
            else:
                q_prepad = (kv_shape[1] - q_shape[1])

        if mask is not None:
            mask = [q_mask, tf.logical_and(k_mask, v_mask)]
        x, weights = self.attention_layer([q, k, v],
                                          mask=mask,
                                          training=training)
        if not self.skip_out:
            x = self.out_layer(x, mask=mask, training=training)
        x_shape = t2t_common.shape_list(x)
        if pad_q_to_kv:
            if q_shape[1] != kv_shape[1]:
                if decode_loop_step is not None:
                    x = tf.slice(x, [0, q_prepad, 0],
                                 [x_shape[0], 1, x_shape[2]])
                else:
                    x = tf.slice(x, [0, q_prepad, 0],
                                 [x_shape[0], q_shape[1], x_shape[2]])
        if self.return_attn_weights:
            return x, weights
        return x
Example #30
def crop(x, crops):
    crops = tf.convert_to_tensor(crops)
    begins, ends = tf.unstack(crops, axis=-1)
    shape = tf.convert_to_tensor(shape_list(x))
    return tf.slice(x, begins, shape - (ends + begins))
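A usage sketch for crop, the counterpart of tf.pad-style paddings: each [begin, end] pair trims that many elements from the corresponding axis:

import tensorflow as tf

x = tf.zeros([4, 100, 8])
print(crop(x, [[0, 0], [2, 3], [0, 0]]).shape)  # (4, 95, 8)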