class NormalizedLogMelExtractor(nn.Module):
    """Compute normalized log-Mel features from a real/imag STFT tensor.

    The input's last axis is expected to hold the real and imaginary STFT
    parts (size 2); ``forward`` returns features in 'bcft' layout
    (batch, channel, mel, time) together with ``seq_len``. Optional
    augmentations (random resampling, Gaussian blur, time/mel masking,
    additive noise) are applied inside ``forward`` under ``torch.no_grad()``.

    >>> x = torch.ones((10, 1, 100, 257, 2))
    >>> NormalizedLogMelExtractor(40, 16000, 512)(x)[0].shape
    >>> NormalizedLogMelExtractor(40, 16000, 512, add_deltas=True, add_delta_deltas=True)(x)[0].shape
    """
    def __init__(
            self, n_mels, sample_rate, fft_length, fmin=50, fmax=None,
            add_deltas=False, add_delta_deltas=False,
            mel_norm_statistics_axis='bt', mel_norm_eps=1e-5,
            # augmentation
            warping_fn=None,
            max_resample_rate=1.,
            blur_sigma=0, blur_kernel_size=5,
            n_time_masks=0, max_masked_time_steps=70, max_masked_time_rate=.2,
            n_mel_masks=0, max_masked_mel_steps=20, max_masked_mel_rate=.2,
            max_noise_scale=0.,
    ):
        """
        Args:
            n_mels: number of Mel filterbank channels.
            sample_rate: audio sample rate in Hz.
            fft_length: FFT length of the STFT that produced the input.
            fmin: lowest Mel filter frequency in Hz.
            fmax: highest Mel filter frequency in Hz (None presumably means
                Nyquist — determined by MelTransform).
            add_deltas: if True, append delta features as extra channels.
            add_delta_deltas: if True, append delta-delta features as extra
                channels.
            mel_norm_statistics_axis: axes string ('bt') over which Norm
                accumulates statistics.
            mel_norm_eps: epsilon of the feature normalization.
            warping_fn: optional Mel filter warping function (augmentation).
            max_resample_rate: values > 1. enable random time resampling.
            blur_sigma: values > 0 enable random Gaussian blurring.
            blur_kernel_size: kernel size of the Gaussian blur.
            n_time_masks: number of random masks on the time axis.
            max_masked_time_steps: max width of each time mask.
            max_masked_time_rate: max masked fraction of the time axis.
            n_mel_masks: number of random masks on the mel axis.
            max_masked_mel_steps: max width of each mel mask.
            max_masked_mel_rate: max masked fraction of the mel axis.
            max_noise_scale: values > 0. enable additive noise augmentation.
        """
        super().__init__()
        self.mel_transform = MelTransform(
            n_mels=n_mels, sample_rate=sample_rate, fft_length=fft_length,
            fmin=fmin, fmax=fmax, log=True, warping_fn=warping_fn,
        )
        self.add_deltas = add_deltas
        self.add_delta_deltas = add_delta_deltas
        self.norm = Norm(
            data_format='bcft',
            shape=(None, 1 + add_deltas + add_delta_deltas, n_mels, None),
            statistics_axis=mel_norm_statistics_axis,
            shift=True,
            scale=True,
            eps=mel_norm_eps,
            independent_axis=None,
            momentum=None,
            interpolation_factor=1.,
        )

        # Augmentation modules; each is None when its trigger arg disables it.
        if max_resample_rate > 1.:
            self.resampler = Resample(
                rate_sampling_fn=LogUniformSampler(
                    scale=2 * np.log(max_resample_rate)
                )
            )
        else:
            self.resampler = None
        if blur_sigma > 0:
            self.blur = GaussianBlur2d(
                kernel_size=blur_kernel_size,
                sigma_sampling_fn=TruncExponentialSampler(
                    shift=.1, scale=blur_sigma, truncation=blur_kernel_size
                )
            )
        else:
            self.blur = None
        if n_time_masks > 0:
            self.time_masking = Mask(
                axis=-1, n_masks=n_time_masks,
                max_masked_steps=max_masked_time_steps,
                max_masked_rate=max_masked_time_rate,
            )
        else:
            self.time_masking = None
        if n_mel_masks > 0:
            self.mel_masking = Mask(
                axis=-2, n_masks=n_mel_masks,
                max_masked_steps=max_masked_mel_steps,
                max_masked_rate=max_masked_mel_rate,
            )
        else:
            self.mel_masking = None
        if max_noise_scale > 0.:
            self.noise = Noise(max_noise_scale)
        else:
            self.noise = None

    def forward(self, x, seq_len=None):
        """Transform STFT input to (augmented) normalized log-Mel features.

        Args:
            x: real/imag STFT tensor; last axis of size 2 — TODO confirm
                layout (batch, channel, time, freq, 2) against callers.
            seq_len: optional sequence lengths for masked statistics.

        Returns:
            Tuple of (features, seq_len).
        """
        with torch.no_grad():
            # Power spectrum = re^2 + im^2, then log-Mel; transpose from
            # (..., time, mel) to (..., mel, time) for the 'bcft' layout.
            x = self.mel_transform(
                torch.sum(x**2, dim=(-1,))
            ).transpose(-2, -1)
            if self.resampler is not None:
                # BUGFIX: the resampler was constructed in __init__ but never
                # applied here, which made max_resample_rate a silent no-op.
                x, seq_len = self.resampler(x, seq_len=seq_len)
            if self.blur is not None:
                x = self.blur(x)
            if self.add_deltas or self.add_delta_deltas:
                deltas = compute_deltas(x)
                if self.add_deltas:
                    x = torch.cat((x, deltas), dim=1)
                if self.add_delta_deltas:
                    delta_deltas = compute_deltas(deltas)
                    x = torch.cat((x, delta_deltas), dim=1)
            x = self.norm(x, seq_len=seq_len)
            if self.time_masking is not None:
                x = self.time_masking(x, seq_len=seq_len)
            if self.mel_masking is not None:
                x = self.mel_masking(x)
            if self.noise is not None:
                x = self.noise(x)
        return x, seq_len

    def inverse(self, x):
        """Undo normalization and the Mel transform (via their inverses)."""
        return self.mel_transform.inverse(
            self.norm.inverse(x).transpose(-2, -1)
        )
class NormalizedLogMelExtractor(nn.Module):
    """Compute normalized log-Mel features from a real/imag STFT tensor,
    with optional input scaling, mixup, STFT sliding-window normalization
    and SpecAugment-style masking applied in ``forward``.

    >>> x = torch.ones((10,1,100,257,2))
    >>> NormalizedLogMelExtractor(40, 16000, 512, stft_norm_window=50)(x)[0].shape
    >>> NormalizedLogMelExtractor(40, 16000, 512, add_deltas=True, add_delta_deltas=True)(x)[0].shape
    """
    def __init__(
            self, n_mels, sample_rate, fft_length, fmin=50, fmax=None,
            stft_norm_window=None, stft_norm_eps=1e-3,
            add_deltas=False, add_delta_deltas=False,
            statistics_axis='bt', scale=True, eps=1e-3,
            # augmentation
            scale_sigma=0., max_scale=4,
            mixup_prob=0., interpolated_mixup=False,
            warping_fn=None,
            max_resample_rate=1.,
            blur_sigma=0, blur_kernel_size=5,
            n_time_masks=0, max_masked_time_steps=70, max_masked_time_rate=.2,
            n_mel_masks=0, max_masked_mel_steps=16, max_masked_mel_rate=.2,
            max_noise_scale=0.,
    ):
        """
        Args:
            n_mels: number of Mel filterbank channels.
            sample_rate: audio sample rate in Hz.
            fft_length: FFT length of the STFT that produced the input.
            fmin: lowest Mel filter frequency in Hz.
            fmax: highest Mel filter frequency in Hz (None presumably means
                Nyquist — determined by MelTransform).
            stft_norm_window: if not None, enables sliding-window scale
                normalization of the STFT magnitudes with this window size.
            stft_norm_eps: epsilon of the STFT window normalization.
            add_deltas: if True, append delta features as extra channels.
            add_delta_deltas: if True, append delta-delta features as extra
                channels.
            statistics_axis: axes string ('bt') over which Norm accumulates
                statistics.
            scale: whether Norm also scales (not only shifts) the features.
            eps: epsilon of the feature normalization.
            scale_sigma: values > 0 enable random input scaling.
            max_scale: truncation bound (in linear scale) of the random
                scaling distribution.
            mixup_prob: values > 0. enable mixup with this probability.
            interpolated_mixup: passed through to Mixup as ``interpolate``.
            warping_fn: optional Mel filter warping function (augmentation).
            max_resample_rate: values > 1. enable random time resampling.
            blur_sigma: values > 0 enable random Gaussian blurring.
            blur_kernel_size: kernel size of the Gaussian blur.
            n_time_masks: number of random masks on the time axis.
            max_masked_time_steps: max width of each time mask.
            max_masked_time_rate: max masked fraction of the time axis.
            n_mel_masks: number of random masks on the mel axis.
            max_masked_mel_steps: max width of each mel mask.
            max_masked_mel_rate: max masked fraction of the mel axis.
            max_noise_scale: values > 0. enable additive noise augmentation.
        """
        super().__init__()
        # Optional per-frame normalization of STFT magnitudes over the
        # frequency axis, computed in a sliding window along time.
        if stft_norm_window is not None:
            self.stft_norm = WindowNorm(
                stft_norm_window, data_format='bctf', shape=None,
                slide_axis='t', statistics_axis='f', independent_axis=None,
                shift=False, scale=True, eps=stft_norm_eps,
            )
        else:
            self.stft_norm = None
        self.mel_transform = MelTransform(
            n_mels=n_mels, sample_rate=sample_rate, fft_length=fft_length,
            fmin=fmin, fmax=fmax, log=True, warping_fn=warping_fn,
        )
        self.add_deltas = add_deltas
        self.add_delta_deltas = add_delta_deltas
        # One channel per feature kind: base (+ deltas) (+ delta-deltas).
        self.norm = Norm(
            data_format='bcft',
            shape=(None, 1 + add_deltas + add_delta_deltas, n_mels, None),
            statistics_axis=statistics_axis,
            scale=scale,
            eps=eps,
            independent_axis=None,
            momentum=None,
        )

        # Augmentation modules; each is None when its trigger arg disables it.
        if scale_sigma > 0:
            self.scale = Scale(
                LogTruncNormalSampler(
                    scale=scale_sigma, truncation=np.log(max_scale)))
        else:
            self.scale = None
        if mixup_prob > 0.:
            self.mixup = Mixup(
                p=mixup_prob,
                weight_sampling_fn=LogUniformSampler(scale=2 * np.log(2.)),
                interpolate=interpolated_mixup)
        else:
            self.mixup = None
        if max_resample_rate > 1.:
            self.resampler = Resample(rate_sampling_fn=LogUniformSampler(
                scale=2 * np.log(max_resample_rate)))
        else:
            self.resampler = None
        if blur_sigma > 0:
            self.blur = GaussianBlur2d(
                kernel_size=blur_kernel_size,
                sigma_sampling_fn=TruncExponentialSampler(
                    shift=.1, scale=blur_sigma))
        else:
            self.blur = None
        if n_time_masks > 0:
            self.time_masking = Mask(
                axis=-1, n_masks=n_time_masks,
                max_masked_steps=max_masked_time_steps,
                max_masked_rate=max_masked_time_rate,
            )
        else:
            self.time_masking = None
        if n_mel_masks > 0:
            self.mel_masking = Mask(
                axis=-2, n_masks=n_mel_masks,
                max_masked_steps=max_masked_mel_steps,
                max_masked_rate=max_masked_mel_rate,
            )
        else:
            self.mel_masking = None
        if max_noise_scale > 0.:
            self.noise = Noise(max_noise_scale)
        else:
            self.noise = None

    def forward(self, x, y=None, seq_len=None):
        """Transform STFT input to (augmented) normalized log-Mel features.

        Args:
            x: real/imag STFT tensor; last axis of size 2 — TODO confirm
                layout (batch, channel, time, freq, 2) against callers.
            y: optional target tensor; adjusted alongside x by mixup and
                resampling, then re-binarized.
            seq_len: optional sequence lengths for masked statistics.

        Returns:
            Tuple of (features, y, seq_len).
        """
        with torch.no_grad():
            if self.scale is not None:
                x = self.scale(x)
            if self.mixup is not None:
                if y is None:
                    x, seq_len = self.mixup(x, seq_len=seq_len)
                else:
                    x, y, seq_len = self.mixup(x, y, seq_len=seq_len)
                    # Mixed targets become soft; re-binarize to multi-hot.
                    y = (y > 0).float()
            if self.stft_norm is not None:
                # Rescale complex STFT bins by normalized/raw magnitude ratio;
                # 1e-6 guards against division by zero magnitudes.
                mag = torch.sqrt((x**2).sum(-1))
                mag_ = self.stft_norm(mag, seq_len=seq_len)
                x = x * (mag_ / (mag + 1e-6)).unsqueeze(-1)
            # Power spectrum = re^2 + im^2, then log-Mel; transpose from
            # (..., time, mel) to (..., mel, time) for the 'bcft' layout.
            x = self.mel_transform(
                torch.sum(x**2, dim=(-1, ))).transpose(-2, -1)
            if self.resampler is not None:
                if y is None or y.dim() == 2:
                    # No targets or clip-level targets: only x is resampled.
                    x, seq_len = self.resampler(x, seq_len=seq_len)
                else:
                    # Frame-level targets are resampled along with x and
                    # re-binarized afterwards.
                    x, y, seq_len = self.resampler(x, y, seq_len=seq_len)
                    y = (y > 0.5).float()
            if self.blur is not None:
                x = self.blur(x)
            if self.add_deltas or self.add_delta_deltas:
                deltas = compute_deltas(x)
                if self.add_deltas:
                    x = torch.cat((x, deltas), dim=1)
                if self.add_delta_deltas:
                    delta_deltas = compute_deltas(deltas)
                    x = torch.cat((x, delta_deltas), dim=1)
            x = self.norm(x, seq_len=seq_len)
            if self.time_masking is not None:
                x = self.time_masking(x, seq_len=seq_len)
            if self.mel_masking is not None:
                x = self.mel_masking(x)
            if self.noise is not None:
                x = self.noise(x)
        return x, y, seq_len

    def inverse(self, x):
        """Undo normalization and the Mel transform (via their inverses)."""
        return self.mel_transform.inverse(
            self.norm.inverse(x).transpose(-2, -1))