def __init__(self, wavenet, sample_rate, fft_length, n_mels, fmin=50, fmax=None):
    super().__init__(
        wavenet=wavenet,
        sample_rate=sample_rate,
        feature_key="features",
        audio_key="audio_data",
    )
    self.mel_transform = MelTransform(
        n_mels=n_mels,
        sample_rate=sample_rate,
        fft_length=fft_length,
        fmin=fmin,
        fmax=fmax,
    )
    self.in_norm = Norm(
        data_format='bcft',
        shape=(None, 1, n_mels, None),
        statistics_axis='bt',
        scale=True,
        independent_axis=None,
        momentum=None,
        interpolation_factor=1.,
    )
def __init__(
        self,
        input_size,
        hidden_size,
        num_heads=1,
        bidirectional=False,
        cross_attention=False,
        norm='layer',
        norm_kwargs={},
        activation='relu',
):
    super().__init__()
    self.activation = ACTIVATION_FN_MAP[activation]()
    self.multi_head_self_attention = MultiHeadAttention(
        input_size, hidden_size, num_heads, bidirectional=bidirectional
    )
    self.cross_attention = cross_attention
    self.hidden = torch.nn.Linear(hidden_size, hidden_size)
    self.out = torch.nn.Linear(hidden_size, hidden_size)
    norm_kwargs = {
        "data_format": 'btc',
        "shape": (None, None, hidden_size),
        "statistics_axis": 'bt',
        **norm_kwargs
    }
    if norm == 'batch':
        norm_kwargs['statistics_axis'] = 'bt'
    elif norm == 'layer':
        # ToDo: where is the difference between layer norm and instance norm?
        norm_kwargs['statistics_axis'] = 'tc'
    elif norm is not None:
        raise ValueError(f'{norm} normalization not known.')
    self.self_attention_norm = None if norm is None else Norm(**norm_kwargs)
    self.output_norm = None if norm is None else Norm(**norm_kwargs)
    if cross_attention:
        self.multi_head_cross_attention = MultiHeadAttention(
            hidden_size, hidden_size, num_heads, bidirectional=True
        )
        self.cross_attention_norm = (
            None if norm is None else Norm(**norm_kwargs)
        )
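# Note (derived from the statistics_axis values above, not from separate
# documentation): inputs follow the 'btc' layout, i.e. (batch, time,
# hidden_size). With norm='batch', statistics are computed over batch and
# time ('bt'); with norm='layer', over time and channels ('tc') of each
# example independently.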
def __init__(self, input_size, output_size):
    super().__init__()
    in_channels = 1
    self.in_norm = Norm(
        data_format='bcft',
        shape=(None, in_channels, input_size, None),
        statistics_axis='bt',
        scale=True,
        independent_axis=None,
        momentum=None,
    )
    self.cnn = CNN2d(
        in_channels=in_channels,
        out_channels=[
            16, 16, 32, 32, 64, 64, 128, 128, 256, 256, 512, 1024, output_size
        ],
        kernel_size=11 * [3] + [2, 1],
        pad_side=11 * ['both'] + 2 * [None],
        pool_size=[1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 1],
        norm='batch',
        activation_fn='relu',
    )
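# Note (derived from the arguments above; the forward pass is not part of
# this excerpt): inputs follow the 'bcft' layout, i.e. (batch, 1,
# input_size, time). The CNN stacks 13 conv layers: eleven 3x3 convolutions
# padded on both sides, followed by two unpadded layers with kernel sizes 2
# and 1; the six pool_size=2 entries give an overall pooling factor of
# 2**6 = 64.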
def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dropout=0.,
        pad_side='both',
        dilation=1,
        stride=1,
        bias=True,
        norm=None,
        norm_kwargs={},
        activation_fn='relu',
        pre_activation=False,
        gated=False,
):
    """
    Args:
        in_channels: number of input channels
        out_channels: number of output channels
        kernel_size: kernel size (an int, or a pair for 2d convolutions)
        dropout: dropout probability
        pad_side: side(s) on which to pad, e.g. 'both', or None for an
            unpadded convolution
        dilation: convolution dilation
        stride: convolution stride
        bias: whether the convolution uses a bias term
        norm: may be None or 'batch'
        norm_kwargs: additional keyword arguments passed to Norm
        activation_fn: key into ACTIVATION_FN_MAP, e.g. 'relu'
        pre_activation: if True, normalization statistics are computed over
            in_channels rather than out_channels
        gated: if True, a parallel gate convolution is created
    """
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    if self.is_2d():
        pad_side = to_pair(pad_side)
        kernel_size = to_pair(kernel_size)
        dilation = to_pair(dilation)
        stride = to_pair(stride)
    self.dropout = dropout
    self.pad_side = pad_side
    self.kernel_size = kernel_size
    self.dilation = dilation
    self.stride = stride
    self.activation_fn = ACTIVATION_FN_MAP[activation_fn]()
    self.pre_activation = pre_activation
    self.gated = gated
    self.conv = self.conv_cls(
        in_channels, out_channels, kernel_size=kernel_size,
        dilation=dilation, stride=stride, bias=bias
    )
    torch.nn.init.xavier_uniform_(self.conv.weight)
    if bias:
        torch.nn.init.zeros_(self.conv.bias)
    if norm is None:
        self.norm = None
    elif norm == 'batch':
        num_channels = in_channels if pre_activation else out_channels
        if self.is_2d():
            self.norm = Norm(
                data_format='bcft',
                shape=(None, num_channels, None, None),
                statistics_axis='bft',
                **norm_kwargs
            )
        else:
            self.norm = Norm(
                data_format='bct',
                shape=(None, num_channels, None),
                statistics_axis='bt',
                **norm_kwargs
            )
    else:
        raise ValueError(f'{norm} normalization not known.')
    if self.gated:
        self.gate_conv = self.conv_cls(
            in_channels, out_channels, kernel_size=kernel_size,
            dilation=dilation, stride=stride, bias=bias
        )
        torch.nn.init.xavier_uniform_(self.gate_conv.weight)
        if bias:
            torch.nn.init.zeros_(self.gate_conv.bias)
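# Note on `gated` (an assumption, since the forward pass is not shown in
# this excerpt): the parallel `gate_conv` with identical hyperparameters
# suggests a gated activation in the WaveNet style, roughly
#   out = activation_fn(conv(x)) * sigmoid(gate_conv(x))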
class NormalizedLogMelExtractor(nn.Module):
    """
    >>> x = torch.ones((10, 1, 100, 257, 2))
    >>> NormalizedLogMelExtractor(40, 16000, 512)(x)[0].shape
    torch.Size([10, 1, 40, 100])
    >>> NormalizedLogMelExtractor(40, 16000, 512, add_deltas=True, add_delta_deltas=True)(x)[0].shape
    torch.Size([10, 3, 40, 100])
    """
    def __init__(
            self,
            n_mels, sample_rate, fft_length, fmin=50, fmax=None,
            add_deltas=False, add_delta_deltas=False,
            mel_norm_statistics_axis='bt', mel_norm_eps=1e-5,
            # augmentation
            warping_fn=None,
            max_resample_rate=1.,
            blur_sigma=0, blur_kernel_size=5,
            n_time_masks=0, max_masked_time_steps=70, max_masked_time_rate=.2,
            n_mel_masks=0, max_masked_mel_steps=20, max_masked_mel_rate=.2,
            max_noise_scale=0.,
    ):
        super().__init__()
        self.mel_transform = MelTransform(
            n_mels=n_mels,
            sample_rate=sample_rate,
            fft_length=fft_length,
            fmin=fmin,
            fmax=fmax,
            log=True,
            warping_fn=warping_fn,
        )
        self.add_deltas = add_deltas
        self.add_delta_deltas = add_delta_deltas
        self.norm = Norm(
            data_format='bcft',
            shape=(None, 1 + add_deltas + add_delta_deltas, n_mels, None),
            statistics_axis=mel_norm_statistics_axis,
            shift=True,
            scale=True,
            eps=mel_norm_eps,
            independent_axis=None,
            momentum=None,
            interpolation_factor=1.,
        )
        # augmentation
        if max_resample_rate > 1.:
            self.resampler = Resample(
                rate_sampling_fn=LogUniformSampler(
                    scale=2 * np.log(max_resample_rate)
                )
            )
        else:
            self.resampler = None
        if blur_sigma > 0:
            self.blur = GaussianBlur2d(
                kernel_size=blur_kernel_size,
                sigma_sampling_fn=TruncExponentialSampler(
                    shift=.1, scale=blur_sigma, truncation=blur_kernel_size
                )
            )
        else:
            self.blur = None
        if n_time_masks > 0:
            self.time_masking = Mask(
                axis=-1, n_masks=n_time_masks,
                max_masked_steps=max_masked_time_steps,
                max_masked_rate=max_masked_time_rate,
            )
        else:
            self.time_masking = None
        if n_mel_masks > 0:
            self.mel_masking = Mask(
                axis=-2, n_masks=n_mel_masks,
                max_masked_steps=max_masked_mel_steps,
                max_masked_rate=max_masked_mel_rate,
            )
        else:
            self.mel_masking = None
        if max_noise_scale > 0.:
            self.noise = Noise(max_noise_scale)
        else:
            self.noise = None

    def forward(self, x, seq_len=None):
        with torch.no_grad():
            # power spectrogram from the trailing real/imag pair, then
            # log-mel features in 'bcft' layout
            x = self.mel_transform(
                torch.sum(x**2, dim=(-1,))
            ).transpose(-2, -1)
            if self.blur is not None:
                x = self.blur(x)
            if self.add_deltas or self.add_delta_deltas:
                deltas = compute_deltas(x)
                if self.add_deltas:
                    x = torch.cat((x, deltas), dim=1)
                if self.add_delta_deltas:
                    delta_deltas = compute_deltas(deltas)
                    x = torch.cat((x, delta_deltas), dim=1)
            x = self.norm(x, seq_len=seq_len)
            if self.time_masking is not None:
                x = self.time_masking(x, seq_len=seq_len)
            if self.mel_masking is not None:
                x = self.mel_masking(x)
            if self.noise is not None:
                x = self.noise(x)
        return x, seq_len

    def inverse(self, x):
        return self.mel_transform.inverse(
            self.norm.inverse(x).transpose(-2, -1)
        )
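# Usage sketch (illustrative, mirroring the doctest above; shapes follow
# from the code, not from separate documentation):
#
#   extractor = NormalizedLogMelExtractor(n_mels=40, sample_rate=16000, fft_length=512)
#   stft = torch.randn(8, 1, 100, 257, 2)   # (batch, channel, time, freq, real/imag)
#   features, seq_len = extractor(stft)
#   assert features.shape == (8, 1, 40, 100)  # (batch, channel, n_mels, time)
#   spec = extractor.inverse(features)  # approximate power spectrogram via mel pseudo-inverse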
class NormalizedLogMelExtractor(nn.Module):
    """
    >>> x = torch.ones((10, 1, 100, 257, 2))
    >>> NormalizedLogMelExtractor(40, 16000, 512, stft_norm_window=50)(x)[0].shape
    torch.Size([10, 1, 40, 100])
    >>> NormalizedLogMelExtractor(40, 16000, 512, add_deltas=True, add_delta_deltas=True)(x)[0].shape
    torch.Size([10, 3, 40, 100])
    """
    def __init__(
            self,
            n_mels, sample_rate, fft_length, fmin=50, fmax=None,
            stft_norm_window=None, stft_norm_eps=1e-3,
            add_deltas=False, add_delta_deltas=False,
            statistics_axis='bt', scale=True, eps=1e-3,
            # augmentation
            scale_sigma=0., max_scale=4,
            mixup_prob=0., interpolated_mixup=False,
            warping_fn=None,
            max_resample_rate=1.,
            blur_sigma=0, blur_kernel_size=5,
            n_time_masks=0, max_masked_time_steps=70, max_masked_time_rate=.2,
            n_mel_masks=0, max_masked_mel_steps=16, max_masked_mel_rate=.2,
            max_noise_scale=0.,
    ):
        super().__init__()
        if stft_norm_window is not None:
            self.stft_norm = WindowNorm(
                stft_norm_window,
                data_format='bctf',
                shape=None,
                slide_axis='t',
                statistics_axis='f',
                independent_axis=None,
                shift=False,
                scale=True,
                eps=stft_norm_eps,
            )
        else:
            self.stft_norm = None
        self.mel_transform = MelTransform(
            n_mels=n_mels,
            sample_rate=sample_rate,
            fft_length=fft_length,
            fmin=fmin,
            fmax=fmax,
            log=True,
            warping_fn=warping_fn,
        )
        self.add_deltas = add_deltas
        self.add_delta_deltas = add_delta_deltas
        self.norm = Norm(
            data_format='bcft',
            shape=(None, 1 + add_deltas + add_delta_deltas, n_mels, None),
            statistics_axis=statistics_axis,
            scale=scale,
            eps=eps,
            independent_axis=None,
            momentum=None,
        )
        # augmentation
        if scale_sigma > 0:
            self.scale = Scale(
                LogTruncNormalSampler(
                    scale=scale_sigma, truncation=np.log(max_scale)
                )
            )
        else:
            self.scale = None
        if mixup_prob > 0.:
            self.mixup = Mixup(
                p=mixup_prob,
                weight_sampling_fn=LogUniformSampler(scale=2 * np.log(2.)),
                interpolate=interpolated_mixup,
            )
        else:
            self.mixup = None
        if max_resample_rate > 1.:
            self.resampler = Resample(
                rate_sampling_fn=LogUniformSampler(
                    scale=2 * np.log(max_resample_rate)
                )
            )
        else:
            self.resampler = None
        if blur_sigma > 0:
            self.blur = GaussianBlur2d(
                kernel_size=blur_kernel_size,
                sigma_sampling_fn=TruncExponentialSampler(
                    shift=.1, scale=blur_sigma
                )
            )
        else:
            self.blur = None
        if n_time_masks > 0:
            self.time_masking = Mask(
                axis=-1, n_masks=n_time_masks,
                max_masked_steps=max_masked_time_steps,
                max_masked_rate=max_masked_time_rate,
            )
        else:
            self.time_masking = None
        if n_mel_masks > 0:
            self.mel_masking = Mask(
                axis=-2, n_masks=n_mel_masks,
                max_masked_steps=max_masked_mel_steps,
                max_masked_rate=max_masked_mel_rate,
            )
        else:
            self.mel_masking = None
        if max_noise_scale > 0.:
            self.noise = Noise(max_noise_scale)
        else:
            self.noise = None

    def forward(self, x, y=None, seq_len=None):
        with torch.no_grad():
            if self.scale is not None:
                x = self.scale(x)
            if self.mixup is not None:
                if y is None:
                    x, seq_len = self.mixup(x, seq_len=seq_len)
                else:
                    x, y, seq_len = self.mixup(x, y, seq_len=seq_len)
                    y = (y > 0).float()
            if self.stft_norm is not None:
                # rescale the complex STFT by a sliding-window norm of its
                # magnitude
                mag = torch.sqrt((x**2).sum(-1))
                mag_ = self.stft_norm(mag, seq_len=seq_len)
                x = x * (mag_ / (mag + 1e-6)).unsqueeze(-1)
            # power spectrogram from the trailing real/imag pair, then
            # log-mel features in 'bcft' layout
            x = self.mel_transform(
                torch.sum(x**2, dim=(-1,))
            ).transpose(-2, -1)
            if self.resampler is not None:
                if y is None or y.dim() == 2:
                    x, seq_len = self.resampler(x, seq_len=seq_len)
                else:
                    x, y, seq_len = self.resampler(x, y, seq_len=seq_len)
                    y = (y > 0.5).float()
            if self.blur is not None:
                x = self.blur(x)
            if self.add_deltas or self.add_delta_deltas:
                deltas = compute_deltas(x)
                if self.add_deltas:
                    x = torch.cat((x, deltas), dim=1)
                if self.add_delta_deltas:
                    delta_deltas = compute_deltas(deltas)
                    x = torch.cat((x, delta_deltas), dim=1)
            x = self.norm(x, seq_len=seq_len)
            if self.time_masking is not None:
                x = self.time_masking(x, seq_len=seq_len)
            if self.mel_masking is not None:
                x = self.mel_masking(x)
            if self.noise is not None:
                x = self.noise(x)
        return x, y, seq_len

    def inverse(self, x):
        return self.mel_transform.inverse(
            self.norm.inverse(x).transpose(-2, -1)
        )
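# Usage sketch (illustrative; the target shape is an assumption, multi-hot
# tags of shape (batch, n_classes) here):
#
#   extractor = NormalizedLogMelExtractor(
#       40, 16000, 512, stft_norm_window=50, mixup_prob=.5, n_time_masks=1,
#   )
#   stft = torch.randn(8, 1, 100, 257, 2)        # (batch, channel, time, freq, real/imag)
#   tags = torch.randint(0, 2, (8, 10)).float()  # (batch, n_classes)
#   features, tags, seq_len = extractor(stft, y=tags)
#   # mixup adjusts y alongside the features (and resampling does so for
#   # framewise targets); masking and noise only affect the features.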