Exemplo n.º 1
0
 def __init__(self, wavenet, sample_rate, fft_length, n_mels,
              fmin=50, fmax=None):
     """Initialize the parent model and build the mel front end.

     The parent constructor receives ``wavenet`` and ``sample_rate`` together
     with the fixed keys ``feature_key="features"`` and
     ``audio_key="audio_data"``.

     Args:
         wavenet: forwarded unchanged to the parent constructor.
         sample_rate: forwarded to the parent constructor and to
             ``MelTransform``.
         fft_length: FFT length passed to ``MelTransform``.
         n_mels: number of mel bands; also fixes the feature size of
             ``self.in_norm``.
         fmin: lower frequency bound passed to ``MelTransform``.
         fmax: upper frequency bound passed to ``MelTransform``; ``None``
             leaves the choice to ``MelTransform``.
     """
     super().__init__(
         wavenet=wavenet,
         sample_rate=sample_rate,
         feature_key="features",
         audio_key="audio_data",
     )

     self.mel_transform = MelTransform(
         n_mels=n_mels, sample_rate=sample_rate, fft_length=fft_length,
         fmin=fmin, fmax=fmax,
     )

     # Normalization over a 'bcft' layout (presumably batch, channel,
     # frequency, time) with one channel and n_mels frequency bins;
     # statistics are pooled over the 'b' and 't' axes.
     mel_feature_shape = (None, 1, n_mels, None)
     self.in_norm = Norm(
         data_format='bcft',
         shape=mel_feature_shape,
         statistics_axis='bt',
         scale=True,
         independent_axis=None,
         momentum=None,
         interpolation_factor=1.,
     )
Exemplo n.º 2
0
    def __init__(
        self,
        input_size,
        hidden_size,
        num_heads=1,
        bidirectional=False,
        cross_attention=False,
        norm='layer',
        norm_kwargs=None,
        activation='relu',
    ):
        """Build an attention layer with optional cross attention.

        Args:
            input_size: feature size of the input to the self-attention.
            hidden_size: feature size used by the attention output, by the
                ``hidden``/``out`` linear layers and by the norms.
            num_heads: number of heads passed to ``MultiHeadAttention``.
            bidirectional: forwarded to the self-attention
                ``MultiHeadAttention``. The cross attention always uses
                ``bidirectional=True``.
            cross_attention: if True, also create
                ``multi_head_cross_attention`` and ``cross_attention_norm``.
            norm: ``None``, ``'batch'`` or ``'layer'``. Selects the
                ``statistics_axis`` of the norms: ``'bt'`` for ``'batch'``
                and ``'tc'`` for ``'layer'``.
            norm_kwargs: optional extra keyword arguments for ``Norm``. They
                override the defaults below, except ``statistics_axis``,
                which is overwritten for ``norm='batch'``/``'layer'``.
                ``None`` (the default) means no extra arguments.
            activation: key into ``ACTIVATION_FN_MAP``.

        Raises:
            ValueError: if ``norm`` is not ``None``, ``'batch'`` or
                ``'layer'``.
        """
        super().__init__()
        self.activation = ACTIVATION_FN_MAP[activation]()
        self.multi_head_self_attention = MultiHeadAttention(
            input_size, hidden_size, num_heads, bidirectional=bidirectional)
        self.cross_attention = cross_attention
        self.hidden = torch.nn.Linear(hidden_size, hidden_size)
        self.out = torch.nn.Linear(hidden_size, hidden_size)

        # Defaults for a 'btc' layout (presumably batch, time, channel). The
        # caller's norm_kwargs take precedence. A fresh dict is built, so the
        # caller's mapping is never mutated by the branches below.
        norm_kwargs = {
            "data_format": 'btc',
            "shape": (None, None, hidden_size),
            "statistics_axis": 'bt',
            **(norm_kwargs or {}),
        }
        if norm is None:
            # NOTE(review): norm=None only sets ``self.norm``. The Norm
            # modules below are still created, with the default (or
            # user-supplied) statistics_axis. Confirm this is intended before
            # relying on norm=None to disable normalization.
            self.norm = None
        elif norm == 'batch':
            norm_kwargs['statistics_axis'] = 'bt'
        elif norm == 'layer':
            norm_kwargs['statistics_axis'] = 'tc'
            # TODO: clarify how this differs from instance normalization.
        else:
            raise ValueError(f'{norm} normalization not known.')
        self.self_attention_norm = Norm(**norm_kwargs)
        self.output_norm = Norm(**norm_kwargs)

        if cross_attention:
            self.multi_head_cross_attention = MultiHeadAttention(
                hidden_size, hidden_size, num_heads, bidirectional=True)
            self.cross_attention_norm = Norm(**norm_kwargs)
Exemplo n.º 3
0
 def __init__(self, input_size, output_size):
     """Build a single-channel input norm followed by a 13-layer ``CNN2d``.

     Args:
         input_size: size of the 'f' axis of the single-channel 'bcft'
             input that ``self.in_norm`` expects.
         output_size: number of output channels of the last CNN layer.
     """
     super().__init__()
     in_channels = 1
     self.in_norm = Norm(
         data_format='bcft',
         shape=(None, in_channels, input_size, None),
         statistics_axis='bt',
         scale=True,
         independent_axis=None,
         momentum=None,
     )

     # Ten layers in pairs of equal width, then 512 and 1024, then the
     # output projection: 13 layers in total.
     channels = [width for width in (16, 32, 64, 128, 256) for _ in range(2)]
     channels += [512, 1024, output_size]
     # The first 11 layers use kernel size 3 and pad on both sides; the last
     # two (kernel sizes 2 and 1) are not padded.
     n_padded_layers = 11
     kernel_sizes = [3] * n_padded_layers + [2, 1]
     pad_sides = ['both'] * n_padded_layers + [None] * 2
     # Alternate pool sizes 1/2 for the first ten layers, then 2, 1, 1.
     pool_sizes = [1, 2] * 5 + [2, 1, 1]

     self.cnn = CNN2d(
         in_channels=in_channels,
         out_channels=channels,
         kernel_size=kernel_sizes,
         pad_side=pad_sides,
         pool_size=pool_sizes,
         norm='batch',
         activation_fn='relu',
     )
Exemplo n.º 4
0
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dropout=0.,
        pad_side='both',
        dilation=1,
        stride=1,
        bias=True,
        norm=None,
        norm_kwargs=None,
        activation_fn='relu',
        pre_activation=False,
        gated=False,
    ):
        """Configure a convolution layer, optionally gated and normalized.

        Args:
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernel_size: convolution kernel size. Passed through ``to_pair``
                when ``self.is_2d()`` is true.
            dropout: dropout rate, stored as ``self.dropout``.
            pad_side: padding-side specification, stored as
                ``self.pad_side``. Passed through ``to_pair`` for 2d layers.
            dilation: convolution dilation. Passed through ``to_pair`` for 2d
                layers.
            stride: convolution stride. Passed through ``to_pair`` for 2d
                layers.
            bias: whether the convolution(s) use a bias; biases are
                initialized to zero.
            norm: ``None`` or ``'batch'``.
            norm_kwargs: extra keyword arguments forwarded to ``Norm`` when
                ``norm == 'batch'``. ``None`` (the default) means none.
            activation_fn: key into ``ACTIVATION_FN_MAP``.
            pre_activation: stored flag. It also makes the norm operate on
                ``in_channels`` (True) instead of ``out_channels`` (False).
            gated: if True, create an additional ``gate_conv`` configured
                exactly like ``conv``.

        Raises:
            ValueError: if ``norm`` is neither ``None`` nor ``'batch'``.
        """
        super().__init__()
        if norm_kwargs is None:
            norm_kwargs = {}
        self.in_channels = in_channels
        self.out_channels = out_channels
        if self.is_2d():
            # 2d layers take per-axis values; to_pair presumably expands a
            # scalar into a pair.
            pad_side = to_pair(pad_side)
            kernel_size = to_pair(kernel_size)
            dilation = to_pair(dilation)
            stride = to_pair(stride)
        self.dropout = dropout
        self.pad_side = pad_side
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.stride = stride
        self.activation_fn = ACTIVATION_FN_MAP[activation_fn]()
        self.pre_activation = pre_activation
        self.gated = gated

        def _make_conv():
            # Build one convolution with Xavier-uniform weights and zero bias.
            conv = self.conv_cls(in_channels,
                                 out_channels,
                                 kernel_size=kernel_size,
                                 dilation=dilation,
                                 stride=stride,
                                 bias=bias)
            torch.nn.init.xavier_uniform_(conv.weight)
            if bias:
                torch.nn.init.zeros_(conv.bias)
            return conv

        self.conv = _make_conv()

        if norm is None:
            self.norm = None
        elif norm == 'batch':
            # With pre-activation the norm sees the layer input, otherwise the
            # convolution output.
            num_channels = in_channels if pre_activation else out_channels
            if self.is_2d():
                self.norm = Norm(data_format='bcft',
                                 shape=(None, num_channels, None, None),
                                 statistics_axis='bft',
                                 **norm_kwargs)
            else:
                self.norm = Norm(data_format='bct',
                                 shape=(None, num_channels, None),
                                 statistics_axis='bt',
                                 **norm_kwargs)
        else:
            raise ValueError(f'{norm} normalization not known.')

        if self.gated:
            self.gate_conv = _make_conv()
Exemplo n.º 5
0
    def __init__(
            self, n_mels, sample_rate, fft_length, fmin=50, fmax=None,
            add_deltas=False, add_delta_deltas=False,
            mel_norm_statistics_axis='bt', mel_norm_eps=1e-5,
            # augmentation
            warping_fn=None,
            max_resample_rate=1.,
            blur_sigma=0, blur_kernel_size=5,
            n_time_masks=0, max_masked_time_steps=70, max_masked_time_rate=.2,
            n_mel_masks=0, max_masked_mel_steps=20, max_masked_mel_rate=.2,
            max_noise_scale=0.,
    ):
        """Build the log-mel transform, its normalization and augmentations.

        Each augmentation attribute is ``None`` unless its argument enables
        it:

        * ``resampler``: ``max_resample_rate > 1``
        * ``blur``: ``blur_sigma > 0``
        * ``time_masking``: ``n_time_masks > 0`` (masks axis -1)
        * ``mel_masking``: ``n_mel_masks > 0`` (masks axis -2)
        * ``noise``: ``max_noise_scale > 0``

        Args:
            n_mels, sample_rate, fft_length, fmin, fmax, warping_fn: passed to
                ``MelTransform`` (always with ``log=True``).
            add_deltas, add_delta_deltas: each one that is set adds one
                channel to the normalized feature shape.
            mel_norm_statistics_axis, mel_norm_eps: ``statistics_axis`` and
                ``eps`` of ``self.norm``.
        """
        super().__init__()
        self.mel_transform = MelTransform(
            n_mels=n_mels,
            sample_rate=sample_rate,
            fft_length=fft_length,
            fmin=fmin,
            fmax=fmax,
            log=True,
            warping_fn=warping_fn,
        )
        self.add_deltas = add_deltas
        self.add_delta_deltas = add_delta_deltas

        # One static channel plus one per enabled delta order (bools add 0/1).
        n_channels = 1 + add_deltas + add_delta_deltas
        self.norm = Norm(
            data_format='bcft',
            shape=(None, n_channels, n_mels, None),
            statistics_axis=mel_norm_statistics_axis,
            shift=True,
            scale=True,
            eps=mel_norm_eps,
            independent_axis=None,
            momentum=None,
            interpolation_factor=1.,
        )

        # ---- augmentation ----
        self.resampler = (
            Resample(rate_sampling_fn=LogUniformSampler(
                scale=2 * np.log(max_resample_rate)))
            if max_resample_rate > 1. else None
        )
        self.blur = (
            GaussianBlur2d(
                kernel_size=blur_kernel_size,
                sigma_sampling_fn=TruncExponentialSampler(
                    shift=.1, scale=blur_sigma,
                    truncation=blur_kernel_size),
            )
            if blur_sigma > 0 else None
        )
        self.time_masking = (
            Mask(axis=-1, n_masks=n_time_masks,
                 max_masked_steps=max_masked_time_steps,
                 max_masked_rate=max_masked_time_rate)
            if n_time_masks > 0 else None
        )
        self.mel_masking = (
            Mask(axis=-2, n_masks=n_mel_masks,
                 max_masked_steps=max_masked_mel_steps,
                 max_masked_rate=max_masked_mel_rate)
            if n_mel_masks > 0 else None
        )
        self.noise = Noise(max_noise_scale) if max_noise_scale > 0. else None
Exemplo n.º 6
0
class NormalizedLogMelExtractor(nn.Module):
    """Compute normalized log-mel features with optional augmentation.

    ``forward`` sums the squares over the last input axis before the mel
    transform. The input is presumably an STFT with real and imaginary parts
    stacked in the last axis (TODO confirm).

    >>> x = torch.ones((10,1,100,257,2))
    >>> NormalizedLogMelExtractor(40, 16000, 512)(x)[0].shape
    torch.Size([10, 1, 40, 100])
    >>> NormalizedLogMelExtractor(40, 16000, 512, add_deltas=True, add_delta_deltas=True)(x)[0].shape
    torch.Size([10, 3, 40, 100])
    """
    def __init__(
            self, n_mels, sample_rate, fft_length, fmin=50, fmax=None,
            add_deltas=False, add_delta_deltas=False,
            mel_norm_statistics_axis='bt', mel_norm_eps=1e-5,
            # augmentation
            warping_fn=None,
            max_resample_rate=1.,
            blur_sigma=0, blur_kernel_size=5,
            n_time_masks=0, max_masked_time_steps=70, max_masked_time_rate=.2,
            n_mel_masks=0, max_masked_mel_steps=20, max_masked_mel_rate=.2,
            max_noise_scale=0.,
    ):
        """Build the mel transform, the normalization and the augmentations.

        Each augmentation attribute is ``None`` unless enabled:

        * ``resampler``: ``max_resample_rate > 1``
        * ``blur``: ``blur_sigma > 0``
        * ``time_masking``: ``n_time_masks > 0`` (axis -1)
        * ``mel_masking``: ``n_mel_masks > 0`` (axis -2)
        * ``noise``: ``max_noise_scale > 0``

        Args:
            n_mels, sample_rate, fft_length, fmin, fmax, warping_fn: passed to
                ``MelTransform`` (always with ``log=True``).
            add_deltas, add_delta_deltas: append first-/second-order deltas
                as extra channels (dim 1) in ``forward``.
            mel_norm_statistics_axis, mel_norm_eps: ``statistics_axis`` and
                ``eps`` of ``self.norm``.
        """
        super().__init__()
        self.mel_transform = MelTransform(
            n_mels=n_mels, sample_rate=sample_rate, fft_length=fft_length,
            fmin=fmin, fmax=fmax, log=True, warping_fn=warping_fn,
        )
        self.add_deltas = add_deltas
        self.add_delta_deltas = add_delta_deltas
        self.norm = Norm(
            data_format='bcft',
            # One static channel plus one per enabled delta order.
            shape=(None, 1 + add_deltas + add_delta_deltas, n_mels, None),
            statistics_axis=mel_norm_statistics_axis,
            shift=True,
            scale=True,
            eps=mel_norm_eps,
            independent_axis=None,
            momentum=None,
            interpolation_factor=1.,
        )

        # augmentation
        if max_resample_rate > 1.:
            self.resampler = Resample(
                rate_sampling_fn=LogUniformSampler(
                    scale=2*np.log(max_resample_rate)
                )
            )
        else:
            self.resampler = None

        if blur_sigma > 0:
            self.blur = GaussianBlur2d(
                kernel_size=blur_kernel_size,
                sigma_sampling_fn=TruncExponentialSampler(
                    shift=.1, scale=blur_sigma, truncation=blur_kernel_size
                )
            )
        else:
            self.blur = None

        if n_time_masks > 0:
            self.time_masking = Mask(
                axis=-1, n_masks=n_time_masks,
                max_masked_steps=max_masked_time_steps,
                max_masked_rate=max_masked_time_rate,
            )
        else:
            self.time_masking = None

        if n_mel_masks > 0:
            self.mel_masking = Mask(
                axis=-2, n_masks=n_mel_masks,
                max_masked_steps=max_masked_mel_steps,
                max_masked_rate=max_masked_mel_rate,
            )
        else:
            self.mel_masking = None

        if max_noise_scale > 0.:
            self.noise = Noise(max_noise_scale)
        else:
            self.noise = None

    def forward(self, x, seq_len=None):
        """Extract normalized log-mel features from ``x``.

        Everything runs under ``torch.no_grad()``. The steps are:
        1. sum of squares over the last axis;
        2. mel transform, then swap the last two axes;
        3. resampling, blur, deltas, normalization, time/mel masking and
           noise (each only if configured).

        Args:
            x: input tensor; its last axis is reduced by summing squares.
            seq_len: optional sequence lengths, forwarded to the resampler,
                the norm and the time masking.

        Returns:
            Tuple ``(features, seq_len)``. ``seq_len`` is the value returned
            by the resampler when resampling is enabled.
        """
        with torch.no_grad():

            x = self.mel_transform(torch.sum(x**2, dim=(-1,))).transpose(-2, -1)

            # Previously the resampler was built in __init__ but never
            # applied, silently ignoring max_resample_rate.
            if self.resampler is not None:
                x, seq_len = self.resampler(x, seq_len=seq_len)

            if self.blur is not None:
                x = self.blur(x)

            if self.add_deltas or self.add_delta_deltas:
                deltas = compute_deltas(x)
                if self.add_deltas:
                    x = torch.cat((x, deltas), dim=1)
                if self.add_delta_deltas:
                    delta_deltas = compute_deltas(deltas)
                    x = torch.cat((x, delta_deltas), dim=1)

            x = self.norm(x, seq_len=seq_len)

            if self.time_masking is not None:
                x = self.time_masking(x, seq_len=seq_len)
            if self.mel_masking is not None:
                x = self.mel_masking(x)

            if self.noise is not None:
                x = self.noise(x)

        return x, seq_len

    def inverse(self, x):
        """Undo the normalization, swap the last two axes back and invert the
        mel transform via ``MelTransform.inverse``."""
        return self.mel_transform.inverse(
            self.norm.inverse(x).transpose(-2, -1)
        )
Exemplo n.º 7
0
    def __init__(
        self,
        n_mels,
        sample_rate,
        fft_length,
        fmin=50,
        fmax=None,
        stft_norm_window=None,
        stft_norm_eps=1e-3,
        add_deltas=False,
        add_delta_deltas=False,
        statistics_axis='bt',
        scale=True,
        eps=1e-3,
        # augmentation
        scale_sigma=0.,
        max_scale=4,
        mixup_prob=0.,
        interpolated_mixup=False,
        warping_fn=None,
        max_resample_rate=1.,
        blur_sigma=0,
        blur_kernel_size=5,
        n_time_masks=0,
        max_masked_time_steps=70,
        max_masked_time_rate=.2,
        n_mel_masks=0,
        max_masked_mel_steps=16,
        max_masked_mel_rate=.2,
        max_noise_scale=0.,
    ):
        """Build the STFT/mel front end, its normalization and augmentations.

        Optional modules are ``None`` unless enabled:

        * ``stft_norm``: ``stft_norm_window`` is not None (``WindowNorm``)
        * ``scale``: ``scale_sigma > 0``
        * ``mixup``: ``mixup_prob > 0``
        * ``resampler``: ``max_resample_rate > 1``
        * ``blur``: ``blur_sigma > 0``
        * ``time_masking``: ``n_time_masks > 0`` (axis -1)
        * ``mel_masking``: ``n_mel_masks > 0`` (axis -2)
        * ``noise``: ``max_noise_scale > 0``

        Note that the boolean ``scale`` argument configures ``self.norm``,
        while the attribute ``self.scale`` is the optional ``Scale``
        augmentation.
        """
        super().__init__()

        self.stft_norm = None if stft_norm_window is None else WindowNorm(
            stft_norm_window,
            data_format='bctf',
            shape=None,
            slide_axis='t',
            statistics_axis='f',
            independent_axis=None,
            shift=False,
            scale=True,
            eps=stft_norm_eps,
        )

        self.mel_transform = MelTransform(
            n_mels=n_mels, sample_rate=sample_rate, fft_length=fft_length,
            fmin=fmin, fmax=fmax, log=True, warping_fn=warping_fn,
        )
        self.add_deltas = add_deltas
        self.add_delta_deltas = add_delta_deltas

        # One static channel plus one per enabled delta order (bools add 0/1).
        n_feature_channels = 1 + add_deltas + add_delta_deltas
        self.norm = Norm(
            data_format='bcft',
            shape=(None, n_feature_channels, n_mels, None),
            statistics_axis=statistics_axis,
            scale=scale,
            eps=eps,
            independent_axis=None,
            momentum=None,
        )

        # ---- augmentation ----
        self.scale = (
            Scale(LogTruncNormalSampler(scale=scale_sigma,
                                        truncation=np.log(max_scale)))
            if scale_sigma > 0 else None
        )
        self.mixup = (
            Mixup(p=mixup_prob,
                  weight_sampling_fn=LogUniformSampler(scale=2 * np.log(2.)),
                  interpolate=interpolated_mixup)
            if mixup_prob > 0. else None
        )
        self.resampler = (
            Resample(rate_sampling_fn=LogUniformSampler(
                scale=2 * np.log(max_resample_rate)))
            if max_resample_rate > 1. else None
        )
        self.blur = (
            GaussianBlur2d(
                kernel_size=blur_kernel_size,
                sigma_sampling_fn=TruncExponentialSampler(shift=.1,
                                                          scale=blur_sigma),
            )
            if blur_sigma > 0 else None
        )
        self.time_masking = (
            Mask(axis=-1, n_masks=n_time_masks,
                 max_masked_steps=max_masked_time_steps,
                 max_masked_rate=max_masked_time_rate)
            if n_time_masks > 0 else None
        )
        self.mel_masking = (
            Mask(axis=-2, n_masks=n_mel_masks,
                 max_masked_steps=max_masked_mel_steps,
                 max_masked_rate=max_masked_mel_rate)
            if n_mel_masks > 0 else None
        )
        self.noise = Noise(max_noise_scale) if max_noise_scale > 0. else None
Exemplo n.º 8
0
class NormalizedLogMelExtractor(nn.Module):
    """Compute normalized log-mel features with optional augmentation.

    ``forward`` sums the squares over the last input axis before the mel
    transform. The input is presumably an STFT with real and imaginary parts
    stacked in the last axis (TODO confirm).

    >>> x = torch.ones((10,1,100,257,2))
    >>> NormalizedLogMelExtractor(40, 16000, 512, stft_norm_window=50)(x)[0].shape
    torch.Size([10, 1, 40, 100])
    >>> NormalizedLogMelExtractor(40, 16000, 512, add_deltas=True, add_delta_deltas=True)(x)[0].shape
    torch.Size([10, 3, 40, 100])
    """
    def __init__(
        self,
        n_mels,
        sample_rate,
        fft_length,
        fmin=50,
        fmax=None,
        stft_norm_window=None,
        stft_norm_eps=1e-3,
        add_deltas=False,
        add_delta_deltas=False,
        statistics_axis='bt',
        scale=True,
        eps=1e-3,
        # augmentation
        scale_sigma=0.,
        max_scale=4,
        mixup_prob=0.,
        interpolated_mixup=False,
        warping_fn=None,
        max_resample_rate=1.,
        blur_sigma=0,
        blur_kernel_size=5,
        n_time_masks=0,
        max_masked_time_steps=70,
        max_masked_time_rate=.2,
        n_mel_masks=0,
        max_masked_mel_steps=16,
        max_masked_mel_rate=.2,
        max_noise_scale=0.,
    ):
        """Build the STFT/mel front end, its normalization and augmentations.

        Optional modules are ``None`` unless enabled:

        * ``stft_norm``: ``stft_norm_window`` is not None (``WindowNorm``)
        * ``scale``: ``scale_sigma > 0``
        * ``mixup``: ``mixup_prob > 0``
        * ``resampler``: ``max_resample_rate > 1``
        * ``blur``: ``blur_sigma > 0``
        * ``time_masking``: ``n_time_masks > 0`` (axis -1)
        * ``mel_masking``: ``n_mel_masks > 0`` (axis -2)
        * ``noise``: ``max_noise_scale > 0``

        The boolean ``scale`` argument configures ``self.norm``; the
        attribute ``self.scale`` is the optional ``Scale`` augmentation.
        """
        super().__init__()
        if stft_norm_window is not None:
            self.stft_norm = WindowNorm(
                stft_norm_window,
                data_format='bctf',
                shape=None,
                slide_axis='t',
                statistics_axis='f',
                independent_axis=None,
                shift=False,
                scale=True,
                eps=stft_norm_eps,
            )
        else:
            self.stft_norm = None
        self.mel_transform = MelTransform(
            n_mels=n_mels,
            sample_rate=sample_rate,
            fft_length=fft_length,
            fmin=fmin,
            fmax=fmax,
            log=True,
            warping_fn=warping_fn,
        )
        self.add_deltas = add_deltas
        self.add_delta_deltas = add_delta_deltas
        self.norm = Norm(
            data_format='bcft',
            # One static channel plus one per enabled delta order.
            shape=(None, 1 + add_deltas + add_delta_deltas, n_mels, None),
            statistics_axis=statistics_axis,
            scale=scale,
            eps=eps,
            independent_axis=None,
            momentum=None,
        )

        # augmentation
        if scale_sigma > 0:
            self.scale = Scale(
                LogTruncNormalSampler(scale=scale_sigma,
                                      truncation=np.log(max_scale)))
        else:
            self.scale = None

        if mixup_prob > 0.:
            self.mixup = Mixup(
                p=mixup_prob,
                weight_sampling_fn=LogUniformSampler(scale=2 * np.log(2.)),
                interpolate=interpolated_mixup)
        else:
            self.mixup = None

        if max_resample_rate > 1.:
            self.resampler = Resample(rate_sampling_fn=LogUniformSampler(
                scale=2 * np.log(max_resample_rate)))
        else:
            self.resampler = None

        if blur_sigma > 0:
            self.blur = GaussianBlur2d(
                kernel_size=blur_kernel_size,
                sigma_sampling_fn=TruncExponentialSampler(shift=.1,
                                                          scale=blur_sigma))
        else:
            self.blur = None

        if n_time_masks > 0:
            self.time_masking = Mask(
                axis=-1,
                n_masks=n_time_masks,
                max_masked_steps=max_masked_time_steps,
                max_masked_rate=max_masked_time_rate,
            )
        else:
            self.time_masking = None

        if n_mel_masks > 0:
            self.mel_masking = Mask(
                axis=-2,
                n_masks=n_mel_masks,
                max_masked_steps=max_masked_mel_steps,
                max_masked_rate=max_masked_mel_rate,
            )
        else:
            self.mel_masking = None

        if max_noise_scale > 0.:
            self.noise = Noise(max_noise_scale)
        else:
            self.noise = None

    def forward(self, x, y=None, seq_len=None):
        """Extract normalized log-mel features, augmenting ``x`` and ``y``.

        Everything runs under ``torch.no_grad()``. Each step runs only if
        configured, in this order:

        1. scale;
        2. mixup;
        3. STFT window normalization;
        4. mel transform, then swap the last two axes;
        5. resampling;
        6. blur;
        7. deltas;
        8. normalization;
        9. time/mel masking;
        10. noise.

        Args:
            x: input tensor; its last axis is reduced by summing squares.
            y: optional targets.
                * Mixed together with ``x`` by mixup, then binarized with
                  ``y > 0``.
                * Resampled together with ``x`` unless ``y`` is 2-dimensional
                  (presumably clip-level labels), then binarized with
                  ``y > 0.5``.
            seq_len: optional sequence lengths, forwarded to and updated by
                the augmentations that accept them.

        Returns:
            Tuple ``(features, y, seq_len)``.
        """
        with torch.no_grad():
            if self.scale is not None:
                x = self.scale(x)
            if self.mixup is not None:
                if y is None:
                    x, seq_len = self.mixup(x, seq_len=seq_len)
                else:
                    x, y, seq_len = self.mixup(x, y, seq_len=seq_len)
                    y = (y > 0).float()
            if self.stft_norm is not None:
                # Rescale x so its magnitude equals the window-normalized
                # magnitude; 1e-6 guards against division by zero.
                mag = torch.sqrt((x**2).sum(-1))
                mag_ = self.stft_norm(mag, seq_len=seq_len)
                x = x * (mag_ / (mag + 1e-6)).unsqueeze(-1)

            x = self.mel_transform(torch.sum(x**2,
                                             dim=(-1, ))).transpose(-2, -1)

            if self.resampler is not None:
                if y is None or y.dim() == 2:
                    x, seq_len = self.resampler(x, seq_len=seq_len)
                else:
                    x, y, seq_len = self.resampler(x, y, seq_len=seq_len)
                    y = (y > 0.5).float()

            if self.blur is not None:
                x = self.blur(x)

            if self.add_deltas or self.add_delta_deltas:
                deltas = compute_deltas(x)
                if self.add_deltas:
                    x = torch.cat((x, deltas), dim=1)
                if self.add_delta_deltas:
                    delta_deltas = compute_deltas(deltas)
                    x = torch.cat((x, delta_deltas), dim=1)

            x = self.norm(x, seq_len=seq_len)

            if self.time_masking is not None:
                x = self.time_masking(x, seq_len=seq_len)
            if self.mel_masking is not None:
                x = self.mel_masking(x)

            if self.noise is not None:
                x = self.noise(x)

        return x, y, seq_len

    def inverse(self, x):
        """Undo the normalization, swap the last two axes back and invert the
        mel transform via ``MelTransform.inverse``."""
        return self.mel_transform.inverse(
            self.norm.inverse(x).transpose(-2, -1))