def compute_spectral_loss(self, encoder, est_target, target, EPS=1e-8):
    batch_size = est_target.shape[0]
    spect_est_target = mag(encoder(est_target)).view(batch_size, -1)
    spect_target = mag(encoder(target)).view(batch_size, -1)
    linear_loss = self.norm1(spect_est_target - spect_target)
    log_loss = self.norm1(
        torch.log(spect_est_target + EPS) - torch.log(spect_target + EPS)
    )
    return linear_loss + self.alpha * log_loss
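# Usage sketch for `compute_spectral_loss`, assuming `self.norm1` reduces to an
# L1 norm per batch item and `self.alpha` weighs the log term; the standalone
# `norm1`, `alpha`, and all shapes below are illustrative stand-ins.
import torch

def norm1(x):
    # L1 norm over the flattened feature axis, one value per batch item.
    return torch.norm(x, p=1, dim=-1)

EPS, alpha = 1e-8, 1.0
spect_est = torch.rand(4, 257 * 100)  # flattened magnitude spectrograms
spect_ref = torch.rand(4, 257 * 100)
linear_loss = norm1(spect_est - spect_ref)
log_loss = norm1(torch.log(spect_est + EPS) - torch.log(spect_ref + EPS))
loss = (linear_loss + alpha * log_loss).mean()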
def test_center_freq_correction(kernel_size, stride_factor):
    spec = torch.randn(2, kernel_size + 2, 50)
    stride = None if stride_factor is None else kernel_size // stride_factor
    new_spec = transforms.centerfreq_correction(spec, kernel_size=kernel_size, stride=stride)
    assert spec.shape == new_spec.shape
    assert_allclose(transforms.mag(spec), transforms.mag(new_spec))
def common_step(self, batch, batch_nb, train=True):
    mix, clean = batch
    mix = unsqueeze_to_3d(mix)
    clean = unsqueeze_to_3d(clean)
    mix_tf = self.model.forward_encoder(mix)
    clean_tf = self.model.forward_encoder(clean)
    true_irm = torch.minimum(
        mag(clean_tf) / mag(mix_tf), torch.tensor(1).type_as(mix_tf)
    )
    est_irm = self.model.forward_masker(mix_tf)
    loss = self.loss_func(est_irm, true_irm)
    return loss
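# Toy check of the clamped IRM target built in `common_step` above: with
# random magnitudes standing in for mag(clean_tf) and mag(mix_tf), the
# clean/mix ratio is clipped at 1 so the mask stays a valid gain.
# Shapes are illustrative, not from the original code.
import torch

clean_mag = torch.rand(2, 257, 100)
mix_mag = torch.rand(2, 257, 100) + 1e-8
true_irm = torch.minimum(clean_mag / mix_mag, torch.tensor(1.0))
assert float(true_irm.max()) <= 1.0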
def phasen_loss_wrapper(est_target, target):
    est_mag = mag(est_target)
    true_mag = mag(target)
    est_mag_comp = est_mag ** 0.3
    true_mag_comp = true_mag ** 0.3
    mag_loss = F.mse_loss(est_mag_comp, true_mag_comp)
    # Scale the complex spectrograms' magnitude to the power 0.3 as well.
    true_comp_coeffs = (true_mag_comp / (1e-8 + true_mag)).repeat(1, 2, 1)
    est_comp_coeffs = (est_mag_comp / (1e-8 + est_mag)).repeat(1, 2, 1)
    phase_loss = F.mse_loss(est_target * est_comp_coeffs, target * true_comp_coeffs)
    return (mag_loss + phase_loss) / 2
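# Sketch of the power-law compression used in `phasen_loss_wrapper`:
# multiplying a complex spectrogram by mag**0.3 / mag compresses its magnitude
# to mag**0.3 while leaving its phase untouched. The concatenated re/im layout
# and shapes below are assumptions for illustration.
import torch

spec = torch.randn(1, 2, 100)  # (batch, re/im, bins)
mag_spec = torch.sqrt(spec[:, :1] ** 2 + spec[:, 1:] ** 2)
coeffs = mag_spec ** 0.3 / (1e-8 + mag_spec)
compressed = spec * coeffs.repeat(1, 2, 1)  # same phase, compressed magnitude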
def test_pmsqe_pit(n_src, sample_rate):
    # Define supported STFT
    if sample_rate == 16000:
        stft = Encoder(STFTFB(kernel_size=512, n_filters=512, stride=256))
    else:
        stft = Encoder(STFTFB(kernel_size=256, n_filters=256, stride=128))
    # Usage by itself
    ref, est = torch.randn(2, n_src, 16000), torch.randn(2, n_src, 16000)
    ref_spec = transforms.mag(stft(ref))
    est_spec = transforms.mag(stft(est))
    loss_func = PITLossWrapper(SingleSrcPMSQE(sample_rate=sample_rate), pit_from="pw_pt")
    # Assert forward ok.
    loss_func(est_spec, ref_spec)
def unpack_data(self, batch, EPS=1e-8):
    mix, sources, noise = batch
    # Take only the first channel
    mix = mix[..., 0]
    sources = sources[..., 0]
    noise = noise[..., 0]
    noise = noise.unsqueeze(1)
    # Compute magnitude spectrograms and IRM
    src_mag_spec = mag(self.model.encoder(sources))
    noise_mag_spec = mag(self.model.encoder(noise))
    noise_mag_spec = noise_mag_spec.unsqueeze(1)
    real_mask = src_mag_spec / (noise_mag_spec + src_mag_spec.sum(1, keepdim=True) + EPS)
    # Get the src idx having the maximum energy
    binary_mask = real_mask.argmax(1)
    return mix, binary_mask, real_mask
def forward_masker(self, tf_rep):
    batch_size = tf_rep.shape[0]
    log_mag = torch.log(mag(tf_rep)).unsqueeze(1)
    if self.has_scaler:
        l = log_mag.shape[-1]
        mean = self.scaler_mean.view(-1, 1).expand(-1, l)
        std = self.scaler_std.view(-1, 1).expand(-1, l)
        log_mag -= mean
        log_mag /= std
    padded = F.pad(log_mag, (self.padding, self.padding, 0, 0), mode='replicate')
    stacks = F.unfold(padded, (self.n_freq, self.padding * 2 + 1))
    new_batch = rearrange(stacks, 'n k l -> (n l) k')
    enc_out1 = torch.tanh(self.enc1(new_batch))
    enc_out2 = torch.tanh(self.enc2(enc_out1))
    enc_out3 = torch.tanh(self.enc3(enc_out2))
    # Split the joint projection into mean and log-variance halves (calling
    # the same layer twice would make mu and logvar identical).
    mu, logvar = self.enc_mu_logvar(enc_out3).chunk(2, dim=-1)
    z = self.reparameterize(mu, logvar)
    dec_1 = self.dec1(z)
    unrolled_masks = self.dec2(dec_1)
    masks = rearrange(unrolled_masks, '(n l) k -> n k l', n=batch_size)
    return masks, mu, logvar
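# Minimal sketch of the sampling step assumed by `self.reparameterize(mu,
# logvar)` above: the standard VAE reparameterization trick, not code from
# the original model.
import torch

def reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)
    return mu + std * torch.randn_like(std)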
def dc_head_separate(self, x):
    """Cluster embeddings to produce binary masks, output waveforms."""
    kmeans = KMeans(n_clusters=self.masker.n_src)
    if len(x.shape) == 2:
        x = x.unsqueeze(1)
    tf_rep = self.encoder(x)
    mag_spec = mag(tf_rep)
    proj, mask_out = self.masker(mag_spec)
    active_bins = ebased_vad(mag_spec)
    active_proj = proj[active_bins.view(1, -1)]
    bin_clusters = kmeans.fit_predict(active_proj.cpu().data.numpy())
    # Create binary masks
    est_mask_list = []
    for i in range(self.masker.n_src):
        # Add ones in all inactive bins in each mask.
        mask = ~active_bins
        mask[active_bins] = torch.from_numpy(bin_clusters == i).to(mask.device)
        est_mask_list.append(mask.float())  # Need float, not bool
    # Go back to time domain
    est_masks = torch.stack(est_mask_list, dim=1)
    masked = apply_mag_mask(tf_rep, est_masks)
    wavs = pad_x_to_y(self.decoder(masked), x)
    dic_out = dict(tfrep=tf_rep, mask=mask_out, masked_tfrep=masked, proj=proj)
    return wavs, dic_out
def common_step(self, batch, batch_nb, train=True):
    mix, clean = batch
    mix = unsqueeze_to_3d(mix)
    clean = unsqueeze_to_3d(clean)
    mix_tf = self.model.forward_encoder(mix)
    clean_tf = self.model.forward_encoder(clean)
    clean_pow = torch.pow(mag(clean_tf), 2)
    mix_pow = torch.pow(mag(mix_tf), 2)
    est_pow, mu, logvar = self.model.forward_vae_mu_logvar(mix_pow)
    loss, rec_loss, kl_loss = self.loss_func(est_pow, clean_pow, mu, logvar)
    self.log("rec_loss", rec_loss, logger=True)
    self.log("kl_loss", kl_loss, logger=True)
    return loss
def unpack_data(self, batch, EPS=1e-8):
    mix, sources = batch
    # Compute magnitude spectrograms and IRM
    src_mag_spec = mag(self.model.encoder(sources))
    real_mask = src_mag_spec / (src_mag_spec.sum(1, keepdim=True) + EPS)
    # Get the src idx having the maximum energy
    binary_mask = real_mask.argmax(1)
    return mix, binary_mask, real_mask
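# Quick check of the ratio masks built by `unpack_data` (this also covers the
# noise-aware variant further up): with non-negative source magnitudes the
# masks sum to (nearly) one over the source axis, and argmax yields a valid
# binary assignment. Shapes here are illustrative.
import torch

src_mag_spec = torch.rand(2, 3, 257, 100)  # (batch, n_src, freq, frames)
real_mask = src_mag_spec / (src_mag_spec.sum(1, keepdim=True) + 1e-8)
binary_mask = real_mask.argmax(1)  # (batch, freq, frames), values in {0, 1, 2}
assert torch.allclose(real_mask.sum(1), torch.ones(2, 257, 100), atol=1e-4)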
def test_pmsqe(sample_rate):
    # Define supported STFT
    if sample_rate == 16000:
        stft = Encoder(STFTFB(kernel_size=512, n_filters=512, stride=256))
    else:
        stft = Encoder(STFTFB(kernel_size=256, n_filters=256, stride=128))
    # Usage by itself
    ref, est = torch.randn(2, 1, 16000), torch.randn(2, 1, 16000)
    ref_spec = transforms.mag(stft(ref))
    est_spec = transforms.mag(stft(est))
    loss_func = SingleSrcPMSQE(sample_rate=sample_rate)
    loss_value = loss_func(est_spec, ref_spec)
    # Assert output has shape (batch,)
    assert loss_value.shape[0] == ref.shape[0]
    # Assert support for transposed inputs.
    tr_loss_value = loss_func(est_spec.transpose(1, 2), ref_spec.transpose(1, 2))
    assert_allclose(loss_value, tr_loss_value)
def test_griffinlim(fb_config, feed_istft, feed_angle):
    stft = Encoder(STFTFB(**fb_config))
    istft = None if not feed_istft else Decoder(STFTFB(**fb_config))
    wav = torch.randn(2, 1, 8000)
    spec = stft(wav)
    tf_mask = torch.sigmoid(torch.randn_like(spec))
    masked_spec = spec * tf_mask
    mag = transforms.mag(masked_spec, -2)
    angles = None if not feed_angle else transforms.angle(masked_spec, -2)
    griffin_lim(mag, stft, angles=angles, istft_dec=istft, n_iter=3)
def distance(estimate, target, is_complex=True):
    """Compute the average distance in the complex plane.

    Makes more sense when the network computes a complex mask.

    Args:
        estimate (torch.Tensor): Estimated complex spectrogram.
        target (torch.Tensor): Target speech complex spectrogram.
        is_complex (bool): Whether to compute the distance in the complex
            or the magnitude space.

    Returns:
        torch.Tensor: The loss value, in a tensor of size 1.
    """
    if is_complex:
        # Take the difference in the complex plane and compute the squared
        # norm of the remaining vector.
        return mag(estimate - target).pow(2).mean()
    else:
        # Compute the mean squared difference between magnitudes.
        return (mag(estimate) - mag(target)).pow(2).mean()
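# Usage sketch for `distance`, assuming `mag` is the magnitude transform used
# throughout this section (asteroid-style concatenated re/im layout). The
# tensors are placeholders.
import torch

estimate = torch.randn(4, 2 * 257, 100)
target = torch.randn(4, 2 * 257, 100)
complex_loss = distance(estimate, target, is_complex=True)
magnitude_loss = distance(estimate, target, is_complex=False)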
def common_step(self, batch, batch_nb, train=False):
    inputs, targets, masks = self.unpack_data(batch)
    embeddings, est_masks = self(inputs)
    spec = mag(self.model.encoder(inputs.unsqueeze(1)))
    if self.mask_mixture:
        est_masks = est_masks * spec.unsqueeze(1)
        masks = masks * spec.unsqueeze(1)
    loss, loss_dic = self.loss_func(
        embeddings, targets, est_src=est_masks, target_src=masks, mix_spec=spec
    )
    return loss, loss_dic
def forward_masker(self, tf_rep):
    tf_rep = tf_rep.unsqueeze(1)
    inp = mag(tf_rep)
    padding = self.median_kernel_size // 2
    padded = F.pad(inp, (0, 0, padding, padding), mode='reflect')
    unfolded = F.unfold(padded, kernel_size=(self.median_kernel_size, 1))
    rearranged = rearrange(unfolded, 'b k (f t) -> b f t k', t=inp.shape[-1])
    return rearranged.median(dim=-1)[0]
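# Toy illustration of the unfold-based median filter in `forward_masker`
# above: each frequency bin is replaced by the median over a vertical window
# of `kernel_size` bins. Names and shapes are illustrative.
import torch
import torch.nn.functional as F

x = torch.rand(1, 1, 9, 5)  # (batch, 1, freq, frames)
kernel_size, padding = 3, 1
padded = F.pad(x, (0, 0, padding, padding), mode='reflect')
unfolded = F.unfold(padded, kernel_size=(kernel_size, 1))  # (1, 3, 9 * 5)
windows = unfolded.view(1, kernel_size, 9, 5).permute(0, 2, 3, 1)
filtered = windows.median(dim=-1)[0]  # (1, 9, 5) median-smoothed magnitudes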
def test_angle_mag_recomposition(dim):
    """Test complex --> (mag, angle) --> complex conversions."""
    max_tested_ndim = 4
    # Random tensor shape
    tensor_shape = [random.randint(1, 10) for _ in range(max_tested_ndim)]
    # Make sure the complex dimension has even shape.
    tensor_shape[dim] = 2 * tensor_shape[dim]
    complex_tensor = torch.randn(tensor_shape)
    phase = transforms.angle(complex_tensor, dim=dim)
    mag = transforms.mag(complex_tensor, dim=dim)
    tensor_back = transforms.from_magphase(mag, phase, dim=dim)
    assert_allclose(complex_tensor, tensor_back)
def forward_masker(self, tf_rep):
    tf_rep = tf_rep.unsqueeze(1)
    if self.target == "TMS":
        input_rep = mag(tf_rep)
    else:
        input_rep = torch.cat(reim(tf_rep), dim=1)
    output_rep = self.masker(input_rep)
    if self.target != "TMS":
        # Fold the two output channels back into a concatenated re/im layout.
        output_rep = torch.cat(torch.chunk(output_rep, 2, dim=1), dim=-2)
    return output_rep.squeeze(1)
def separate(self, x):
    """Separate with the mask-inference head, output waveforms."""
    if len(x.shape) == 2:
        x = x.unsqueeze(1)
    tf_rep = self.encoder(x)
    proj, mask_out = self.masker(mag(tf_rep))
    masked = apply_mag_mask(tf_rep.unsqueeze(1), mask_out)
    wavs = torch_utils.pad_x_to_y(self.decoder(masked), x)
    dic_out = dict(tfrep=tf_rep, mask=mask_out, masked_tfrep=masked, proj=proj)
    return wavs, dic_out
def test_pcen_forward(n_channels, batch_size):
    audio = torch.randn(batch_size, n_channels, 16000 * 10)
    fb = STFTFB(kernel_size=256, n_filters=256, stride=128)
    enc = Encoder(fb)
    tf_rep = enc(audio)
    mag_spec = transforms.mag(tf_rep)
    pcen = PCEN(n_channels=n_channels)
    energy = pcen(mag_spec)
    expected_shape = mag_spec.shape
    assert energy.shape == expected_shape
def common_step(self, batch, batch_nb, train=True):
    mix, clean = batch
    mix = unsqueeze_to_3d(mix)
    clean = unsqueeze_to_3d(clean)
    mix_tf = self.model.forward_encoder(mix)
    clean_tf = self.model.forward_encoder(clean)
    clean_pow = torch.pow(mag(clean_tf), 2)
    # HACK: train on the clean data to avoid changing the datasets.
    est_pow, mu, logvar = self.model.forward_vae_mu_logvar(clean_pow)
    loss, rec_loss, kl_loss = self.loss_func(est_pow, clean_pow, mu, logvar)
    self.log("rec_loss", rec_loss, logger=True)
    self.log("kl_loss", kl_loss, logger=True)
    return loss
def forward(self, x):
    if len(x.shape) == 2:
        x = x.unsqueeze(1)
    # Compute STFT
    tf_rep = self.encoder(x)
    # Estimate TF mask from STFT features: cat([re, im, mag])
    if self.is_complex:
        to_masker = magreim(tf_rep)
    else:
        to_masker = mag(tf_rep)
    # The LSTM masker expects the feature dimension last (unlike 1D conv).
    est_masks = self.masker(to_masker.transpose(1, 2)).transpose(1, 2)
    # Apply the TF mask
    if self.is_complex:
        masked_tf_rep = apply_real_mask(tf_rep, est_masks)
    else:
        masked_tf_rep = apply_mag_mask(tf_rep, est_masks)
    return masked_tf_rep
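# Sketch of the two masker inputs used in `forward` above, assuming an
# asteroid-style concatenated re/im spectrogram of shape
# (batch, 2 * freq, frames): the complex branch stacks [re, im, mag]
# (3 * freq features), the magnitude branch uses mag alone (freq features).
import torch

tf_rep = torch.randn(4, 2 * 257, 100)
re, im = torch.chunk(tf_rep, 2, dim=1)
mag_spec = torch.sqrt(re ** 2 + im ** 2)
to_masker_complex = torch.cat([re, im, mag_spec], dim=1)  # (4, 3 * 257, 100)
to_masker_mag = mag_spec                                  # (4, 257, 100)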
def test_misi(fb_config, feed_istft, feed_angle):
    stft = Encoder(STFTFB(**fb_config))
    istft = None if not feed_istft else Decoder(STFTFB(**fb_config))
    n_src = 3
    # Create mixture
    wav = torch.randn(2, 1, 8000)
    spec = stft(wav).unsqueeze(1)
    # Create n_src masks on mixture spec and apply them
    shape = list(spec.shape)
    shape[1] *= n_src
    tf_mask = torch.sigmoid(torch.randn(*shape))
    masked_specs = spec * tf_mask
    # Separate mag and angle.
    mag = transforms.mag(masked_specs, -2)
    angles = None if not feed_angle else transforms.angle(masked_specs, -2)
    est_wavs = misi(wav, mag, stft, angles=angles, istft_dec=istft, n_iter=2)
    # We actually don't know the last dim because ISTFT(STFT()) cuts the end
    assert est_wavs.shape[:-1] == (2, n_src)
def forward_masker(self, tf_rep):
    batch_size = tf_rep.shape[0]
    log_mag = torch.log(mag(tf_rep)).unsqueeze(1)
    if self.has_scaler:
        l = log_mag.shape[-1]
        mean = self.scaler_mean.view(-1, 1).expand(-1, l)
        std = self.scaler_std.view(-1, 1).expand(-1, l)
        log_mag -= mean
        log_mag /= std
    padded = F.pad(log_mag, (self.padding, self.padding, 0, 0), mode='replicate')
    stacks = F.unfold(padded, (self.n_freq, self.padding * 2 + 1))
    new_batch = rearrange(stacks, 'n k l -> (n l) k')
    unrolled_masks = self.masker(new_batch)
    masks = rearrange(unrolled_masks, '(n l) k -> n k l', n=batch_size)
    return masks
def forward_masker(self, tf_rep):
    """Estimate masks from a time-frequency representation.

    Args:
        tf_rep (torch.Tensor): Time-frequency representation in
            (batch, freq, seq).

    Returns:
        torch.Tensor: Estimated masks in (batch, freq, seq).
    """
    masker_input = tf_rep
    if self.input_type == "mag":
        masker_input = mag(masker_input)
    elif self.input_type == "cat":
        masker_input = magreim(masker_input)
    est_masks = self.masker(masker_input)
    if self.output_type == "mag":
        # Duplicate the magnitude mask so it gates both re and im parts.
        est_masks = est_masks.repeat(1, 2, 1)
    return est_masks
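# Why the "mag" output is repeated in `forward_masker` above: a magnitude
# mask has freq features but must gate both the real and imaginary halves of
# a concatenated re/im representation. Illustrative shapes only.
import torch

mag_mask = torch.sigmoid(torch.randn(4, 257, 100))
tf_rep = torch.randn(4, 2 * 257, 100)
masked = tf_rep * mag_mask.repeat(1, 2, 1)  # same gain on re and im parts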
def common_step(self, batch, batch_nb, train=True):
    mix, clean = batch
    mix = unsqueeze_to_3d(mix)
    clean = unsqueeze_to_3d(clean)
    mix_tf = self.model.forward_encoder(mix)
    clean_tf = self.model.forward_encoder(clean)
    model_output = self.model.forward_masker(mix_tf)
    if self.model.target == "cIRM":
        target_mask = perfect_cirm(mix_tf, clean_tf)
        loss = self.loss_func(model_output, target_mask)
    elif self.model.target == "TMS":
        loss = self.loss_func(model_output, mag(clean_tf))
    else:
        loss = self.loss_func(model_output, clean_tf)
    return loss
def compute_scaler(self, data_iter):
    count = 0
    total_sum = torch.zeros(self.n_freq)
    total_sum_2 = torch.zeros(self.n_freq)
    for batch in tqdm(data_iter, 'Computing scaler'):
        mix, _ = batch
        mix = _unsqueeze_to_3d(mix)
        tf_rep = self.forward_encoder(mix)
        log_mag = torch.log(mag(tf_rep))
        total_sum += torch.sum(log_mag, dim=(0, 2))
        total_sum_2 += torch.sum(log_mag.pow(2), dim=(0, 2))
        count += log_mag.shape[0] * log_mag.shape[2]
    mean = total_sum / count
    # Unbiased variance recovered from the running sums.
    variance = (total_sum_2 / count - mean.pow(2)) * (count / (count - 1))
    std = torch.sqrt(variance)
    self.scaler_mean = mean
    self.scaler_std = std
    self.has_scaler = True
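# Sanity-check sketch for the running statistics in `compute_scaler`:
# accumulating per-frequency sums and squared sums over batches reproduces
# the mean and unbiased std of the concatenated data. The data below is fake.
import torch

chunks = [torch.randn(8, 129, 50) for _ in range(3)]  # stand-in log-magnitudes
count, total, total_2 = 0, torch.zeros(129), torch.zeros(129)
for log_mag in chunks:
    total += log_mag.sum(dim=(0, 2))
    total_2 += log_mag.pow(2).sum(dim=(0, 2))
    count += log_mag.shape[0] * log_mag.shape[2]
mean = total / count
std = torch.sqrt((total_2 / count - mean.pow(2)) * count / (count - 1))
ref = torch.cat(chunks, dim=0).transpose(1, 2).reshape(-1, 129)
assert torch.allclose(mean, ref.mean(0), atol=1e-3)
assert torch.allclose(std, ref.std(0), atol=1e-3)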
def forward(self, x):
    if len(x.shape) == 2:
        x = x.unsqueeze(1)
    tf_rep = self.encoder(x)
    final_proj, mask_out = self.masker(mag(tf_rep))
    return final_proj, mask_out
def forward_masker(self, tf_rep):
    output, _, _ = self.forward_vae_mu_logvar(torch.pow(mag(tf_rep), 2))
    return from_magphase(torch.sqrt(output), angle(tf_rep))
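# Sketch of the magnitude/phase recombination in `forward_masker` above,
# assuming the same `transforms` module used by the tests in this section:
# the VAE outputs a power spectrum, its square root serves as the estimated
# magnitude, and the noisy phase is reused. Shapes are illustrative.
import torch

tf_rep = torch.randn(2, 2 * 257, 100)  # concatenated re/im spectrogram
est_power = torch.rand(2, 257, 100)    # stand-in for the VAE output
est_spec = transforms.from_magphase(torch.sqrt(est_power), transforms.angle(tf_rep))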
def forward_masker(self, tf_rep):
    return self.run_MCEM(mag(tf_rep))
def test_mag(encoder_list):
    for enc, fb_dim in encoder_list:
        tf_rep = enc(torch.randn(2, 1, 16000))  # [batch, freq, time]
        batch, freq, time = tf_rep.shape
        mag = transforms.mag(tf_rep, dim=1)
        assert mag.shape == (batch, freq // 2, time)