Example 1
def compute_spectral_loss(self, encoder, est_target, target, EPS=1e-8):
    batch_size = est_target.shape[0]
    spect_est_target = mag(encoder(est_target)).view(batch_size, -1)
    spect_target = mag(encoder(target)).view(batch_size, -1)
    linear_loss = self.norm1(spect_est_target - spect_target)
    log_loss = self.norm1(
        torch.log(spect_est_target + EPS) - torch.log(spect_target + EPS))
    return linear_loss + self.alpha * log_loss
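For reference, the same linear-plus-log spectral loss as a self-contained sketch; `alpha` and the plain L1 sums stand in for the original's `self.alpha` and `self.norm1` (illustrative names, not from the source):

import torch

def spectral_loss_sketch(spect_est, spect_ref, alpha=1.0, eps=1e-8):
    # L1 distance between the magnitude spectrograms.
    linear_loss = (spect_est - spect_ref).abs().sum()
    # L1 distance between the log-magnitude spectrograms.
    log_loss = (torch.log(spect_est + eps) - torch.log(spect_ref + eps)).abs().sum()
    return linear_loss + alpha * log_loss

loss = spectral_loss_sketch(torch.rand(2, 257, 100), torch.rand(2, 257, 100))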
Example 2
def test_center_freq_correction(kernel_size, stride_factor):
    spec = torch.randn(2, kernel_size + 2, 50)
    stride = None if stride_factor is None else kernel_size // stride_factor
    new_spec = transforms.centerfreq_correction(spec,
                                                kernel_size=kernel_size,
                                                stride=stride)
    assert spec.shape == new_spec.shape
    assert_allclose(transforms.mag(spec), transforms.mag(new_spec))
Example 3
def common_step(self, batch, batch_nb, train=True):
    mix, clean = batch
    mix = unsqueeze_to_3d(mix)
    clean = unsqueeze_to_3d(clean)

    mix_tf = self.model.forward_encoder(mix)
    clean_tf = self.model.forward_encoder(clean)

    true_irm = torch.minimum(mag(clean_tf) / mag(mix_tf),
                             torch.tensor(1.0).type_as(mix_tf))
    est_irm = self.model.forward_masker(mix_tf)
    loss = self.loss_func(est_irm, true_irm)
    return loss
Example 4
def phasen_loss_wrapper(est_target, target):
    est_mag = mag(est_target)
    true_mag = mag(target)
    # Power-law compress the magnitudes (amplitude to the power 0.3).
    est_mag_comp = est_mag**0.3
    true_mag_comp = true_mag**0.3

    mag_loss = F.mse_loss(est_mag_comp, true_mag_comp)

    # Compress the complex spectrograms' magnitudes to the power 0.3 as well.
    true_comp_coeffs = (true_mag_comp / (1e-8 + true_mag)).repeat(1, 2, 1)
    est_comp_coeffs = (est_mag_comp / (1e-8 + est_mag)).repeat(1, 2, 1)
    phase_loss = F.mse_loss(est_target * est_comp_coeffs, target * true_comp_coeffs)

    return (mag_loss + phase_loss) / 2
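A hypothetical call, assuming `mag` and `F` (torch.nn.functional) are in scope as in the snippet; the inputs mimic re/im-stacked spectrograms of shape (batch, 2 * freq, time), so `mag(...)` halves the second dimension and the `repeat(1, 2, 1)` above restores the stacked layout:

import torch

est = torch.randn(4, 2 * 257, 100)
ref = torch.randn(4, 2 * 257, 100)
loss = phasen_loss_wrapper(est, ref)  # scalar tensor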
Example 5
def test_pmsqe_pit(n_src, sample_rate):
    # Define supported STFT
    if sample_rate == 16000:
        stft = Encoder(STFTFB(kernel_size=512, n_filters=512, stride=256))
    else:
        stft = Encoder(STFTFB(kernel_size=256, n_filters=256, stride=128))
    # Usage by itself
    ref, est = torch.randn(2, n_src, 16000), torch.randn(2, n_src, 16000)
    ref_spec = transforms.mag(stft(ref))
    est_spec = transforms.mag(stft(est))
    loss_func = PITLossWrapper(SingleSrcPMSQE(sample_rate=sample_rate),
                               pit_from="pw_pt")
    # Assert forward ok.
    loss_func(est_spec, ref_spec)
Example 6
def unpack_data(self, batch, EPS=1e-8):
    mix, sources, noise = batch
    # Take only the first channel
    mix = mix[..., 0]
    sources = sources[..., 0]
    noise = noise[..., 0]
    noise = noise.unsqueeze(1)
    # Compute magnitude spectrograms and IRM
    src_mag_spec = mag(self.model.encoder(sources))
    noise_mag_spec = mag(self.model.encoder(noise))
    noise_mag_spec = noise_mag_spec.unsqueeze(1)
    real_mask = src_mag_spec / (noise_mag_spec +
                                src_mag_spec.sum(1, keepdim=True) + EPS)
    # Get the src idx having the maximum energy
    binary_mask = real_mask.argmax(1)
    return mix, binary_mask, real_mask
Example 7
def forward_masker(self, tf_rep):
    batch_size = tf_rep.shape[0]
    log_mag = torch.log(mag(tf_rep)).unsqueeze(1)

    if self.has_scaler:
        length = log_mag.shape[-1]
        mean = self.scaler_mean.view(-1, 1).expand(-1, length)
        std = self.scaler_std.view(-1, 1).expand(-1, length)
        log_mag -= mean
        log_mag /= std

    # Stack each frame with its `padding` left and right neighbour frames.
    padded = F.pad(log_mag, (self.padding, self.padding, 0, 0), mode='replicate')
    stacks = F.unfold(padded, (self.n_freq, self.padding * 2 + 1))
    new_batch = rearrange(stacks, 'n k l -> (n l) k')

    enc_out1 = torch.tanh(self.enc1(new_batch))
    enc_out2 = torch.tanh(self.enc2(enc_out1))
    enc_out3 = torch.tanh(self.enc3(enc_out2))
    # enc_mu_logvar is assumed to return the (mu, logvar) pair; calling it
    # twice on the same input would just yield mu == logvar.
    mu, logvar = self.enc_mu_logvar(enc_out3)
    z = self.reparameterize(mu, logvar)
    dec_1 = self.dec1(z)
    unrolled_masks = self.dec2(dec_1)

    masks = rearrange(unrolled_masks, '(n l) k -> n k l', n=batch_size)
    return masks, mu, logvar
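`self.reparameterize` is not shown; a standard VAE reparameterization consistent with the (mu, logvar) interface above would look like this (an assumption, not code from the original):

import torch

def reparameterize(mu, logvar):
    # z = mu + sigma * eps with eps ~ N(0, I); logvar parameterizes log(sigma^2).
    std = torch.exp(0.5 * logvar)
    return mu + std * torch.randn_like(std)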
Example 8
def dc_head_separate(self, x):
    """Cluster embeddings to produce binary masks, output waveforms."""
    kmeans = KMeans(n_clusters=self.masker.n_src)
    if len(x.shape) == 2:
        x = x.unsqueeze(1)
    tf_rep = self.encoder(x)
    mag_spec = mag(tf_rep)
    proj, mask_out = self.masker(mag_spec)
    active_bins = ebased_vad(mag_spec)
    active_proj = proj[active_bins.view(1, -1)]
    # Cluster the embeddings of the active TF bins.
    bin_clusters = kmeans.fit_predict(active_proj.cpu().data.numpy())
    # Create binary masks
    est_mask_list = []
    for i in range(self.masker.n_src):
        # Add ones in all inactive bins in each mask.
        mask = ~active_bins
        mask[active_bins] = torch.from_numpy(
            (bin_clusters == i)).to(mask.device)
        est_mask_list.append(mask.float())  # Need float, not bool
    # Go back to time domain
    est_masks = torch.stack(est_mask_list, dim=1)
    masked = apply_mag_mask(tf_rep, est_masks)
    wavs = pad_x_to_y(self.decoder(masked), x)
    dic_out = dict(tfrep=tf_rep,
                   mask=mask_out,
                   masked_tfrep=masked,
                   proj=proj)
    return wavs, dic_out
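The cluster-then-mask step in isolation, with random stand-ins for the embeddings and the voice-activity decision (all names and shapes illustrative):

import torch
from sklearn.cluster import KMeans

n_src, n_bins, emb_dim = 2, 500, 20
proj = torch.randn(n_bins, emb_dim)        # one embedding per TF bin
active = torch.rand(n_bins) > 0.3          # stand-in VAD decision per bin

labels = KMeans(n_clusters=n_src, n_init=10).fit_predict(proj[active].numpy())
masks = []
for i in range(n_src):
    m = ~active                            # inactive bins get ones in every mask
    m[active] = torch.from_numpy(labels == i)
    masks.append(m.float())
est_masks = torch.stack(masks, dim=0)      # (n_src, n_bins)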
Example 9
def common_step(self, batch, batch_nb, train=True):
    mix, clean = batch
    mix = unsqueeze_to_3d(mix)
    clean = unsqueeze_to_3d(clean)

    mix_tf = self.model.forward_encoder(mix)
    clean_tf = self.model.forward_encoder(clean)

    clean_pow = torch.pow(mag(clean_tf), 2)
    mix_pow = torch.pow(mag(mix_tf), 2)

    est_pow, mu, logvar = self.model.forward_vae_mu_logvar(mix_pow)

    loss, rec_loss, kl_loss = self.loss_func(est_pow, clean_pow, mu, logvar)
    self.log("rec_loss", rec_loss, logger=True)
    self.log("kl_loss", kl_loss, logger=True)
    return loss
Example 10
def unpack_data(self, batch, EPS=1e-8):
    mix, sources = batch
    # Compute magnitude spectrograms and IRM
    src_mag_spec = mag(self.model.encoder(sources))
    real_mask = src_mag_spec / (src_mag_spec.sum(1, keepdim=True) + EPS)
    # Get the src idx having the maximum energy
    binary_mask = real_mask.argmax(1)
    return mix, binary_mask, real_mask
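For reference, the same ideal-ratio-mask construction on random magnitudes (shapes illustrative): each source's share of the summed magnitude gives the real mask, and the argmax over sources gives a binary-mask training target.

import torch

src_mag_spec = torch.rand(2, 3, 257, 100)  # (batch, n_src, freq, time)
eps = 1e-8
real_mask = src_mag_spec / (src_mag_spec.sum(1, keepdim=True) + eps)
binary_mask = real_mask.argmax(1)          # (batch, freq, time)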
Example 11
def test_pmsqe(sample_rate):
    # Define supported STFT
    if sample_rate == 16000:
        stft = Encoder(STFTFB(kernel_size=512, n_filters=512, stride=256))
    else:
        stft = Encoder(STFTFB(kernel_size=256, n_filters=256, stride=128))
    # Usage by itself
    ref, est = torch.randn(2, 1, 16000), torch.randn(2, 1, 16000)
    ref_spec = transforms.mag(stft(ref))
    est_spec = transforms.mag(stft(est))
    loss_func = SingleSrcPMSQE(sample_rate=sample_rate)
    loss_value = loss_func(est_spec, ref_spec)
    # Assert output has shape (batch,)
    assert loss_value.shape[0] == ref.shape[0]
    # Assert support for transposed inputs.
    tr_loss_value = loss_func(est_spec.transpose(1, 2),
                              ref_spec.transpose(1, 2))
    assert_allclose(loss_value, tr_loss_value)
Example 12
def test_griffinlim(fb_config, feed_istft, feed_angle):
    stft = Encoder(STFTFB(**fb_config))
    istft = None if not feed_istft else Decoder(STFTFB(**fb_config))
    wav = torch.randn(2, 1, 8000)
    spec = stft(wav)
    tf_mask = torch.sigmoid(torch.randn_like(spec))
    masked_spec = spec * tf_mask
    mag = transforms.mag(masked_spec, -2)
    angles = None if not feed_angle else transforms.angle(masked_spec, -2)
    griffin_lim(mag, stft, angles=angles, istft_dec=istft, n_iter=3)
Example 13
def distance(estimate, target, is_complex=True):
    """Compute the average distance in the complex plane. Makes more sense
    when the network computes a complex mask.

    Args:
        estimate (torch.Tensor): Estimate complex spectrogram.
        target (torch.Tensor): Speech target complex spectrogram.
        is_complex (bool): Whether to compute the distance in the complex or
            the magnitude space.

    Returns:
        torch.Tensor: The loss value, as a scalar tensor.
    """
    if is_complex:
        # Take the difference in the complex plane and compute the squared norm
        # of the remaining vector.
        return mag(estimate - target).pow(2).mean()
    else:
        # Compute the mean difference between magnitudes.
        return (mag(estimate) - mag(target)).pow(2).mean()
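A minimal usage sketch, defining a stand-in `mag` for re/im-stacked tensors (real and imaginary parts concatenated along dim -2); this mirrors, but is not, the library transform the snippet relies on:

import torch

def mag(x, dim=-2):
    re, im = torch.chunk(x, 2, dim=dim)
    return torch.sqrt(re ** 2 + im ** 2)

est = torch.randn(2, 2 * 257, 100)  # (batch, 2 * freq, time)
ref = torch.randn(2, 2 * 257, 100)
loss_complex = distance(est, ref, is_complex=True)
loss_mag = distance(est, ref, is_complex=False)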
Example 14
def common_step(self, batch, batch_nb, train=False):
    inputs, targets, masks = self.unpack_data(batch)
    embeddings, est_masks = self(inputs)
    spec = mag(self.model.encoder(inputs.unsqueeze(1)))
    if self.mask_mixture:
        est_masks = est_masks * spec.unsqueeze(1)
        masks = masks * spec.unsqueeze(1)
    loss, loss_dic = self.loss_func(
        embeddings, targets, est_src=est_masks, target_src=masks, mix_spec=spec
    )
    return loss, loss_dic
Example 15
def forward_masker(self, tf_rep):
    tf_rep = tf_rep.unsqueeze(1)
    inp = mag(tf_rep)

    padding = self.median_kernel_size // 2
    padded = F.pad(inp, (0, 0, padding, padding), mode='reflect')
    unfolded = F.unfold(padded, kernel_size=(self.median_kernel_size, 1))
    rearranged = rearrange(unfolded,
                           'b k (f t) -> b f t k',
                           t=inp.shape[-1])
    return rearranged.median(dim=-1)[0]
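The pad/unfold/median pattern above is a generic one-axis median filter; a self-contained sketch over the frequency axis (names illustrative):

import torch
import torch.nn.functional as F
from einops import rearrange

def median_filter_freq(spec, kernel_size=5):
    # spec: (batch, freq, time); median-filter along the frequency axis.
    pad = kernel_size // 2
    x = F.pad(spec.unsqueeze(1), (0, 0, pad, pad), mode='reflect')
    patches = F.unfold(x, kernel_size=(kernel_size, 1))      # (b, k, f * t)
    patches = rearrange(patches, 'b k (f t) -> b f t k', t=spec.shape[-1])
    return patches.median(dim=-1).values                     # (b, f, t)

out = median_filter_freq(torch.rand(2, 257, 100))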
Example 16
def test_angle_mag_recomposition(dim):
    """Test complex --> (mag, angle) --> complex conversions."""
    max_tested_ndim = 4
    # Random tensor shape
    tensor_shape = [random.randint(1, 10) for _ in range(max_tested_ndim)]
    # Make sure the complex dimension has even shape
    tensor_shape[dim] = 2 * tensor_shape[dim]
    complex_tensor = torch.randn(tensor_shape)
    phase = transforms.angle(complex_tensor, dim=dim)
    mag = transforms.mag(complex_tensor, dim=dim)
    tensor_back = transforms.from_magphase(mag, phase, dim=dim)
    assert_allclose(complex_tensor, tensor_back)
Example 17
def forward_masker(self, tf_rep):
    tf_rep = tf_rep.unsqueeze(1)

    if self.target == "TMS":
        input_rep = mag(tf_rep)
    else:
        input_rep = torch.cat(reim(tf_rep), dim=1)

    output_rep = self.masker(input_rep)
    if self.target != "TMS":
        output_rep = torch.cat(torch.chunk(output_rep, 2, dim=1), dim=-2)

    return output_rep.squeeze(1)
Example 18
def separate(self, x):
    """Separate with mask-inference head, output waveforms."""
    if len(x.shape) == 2:
        x = x.unsqueeze(1)
    tf_rep = self.encoder(x)
    proj, mask_out = self.masker(mag(tf_rep))
    masked = apply_mag_mask(tf_rep.unsqueeze(1), mask_out)
    wavs = torch_utils.pad_x_to_y(self.decoder(masked), x)
    dic_out = dict(tfrep=tf_rep,
                   mask=mask_out,
                   masked_tfrep=masked,
                   proj=proj)
    return wavs, dic_out
Example 19
def test_pcen_forward(n_channels, batch_size):
    audio = torch.randn(batch_size, n_channels, 16000 * 10)

    fb = STFTFB(kernel_size=256, n_filters=256, stride=128)
    enc = Encoder(fb)
    tf_rep = enc(audio)
    mag_spec = transforms.mag(tf_rep)

    pcen = PCEN(n_channels=n_channels)
    energy = pcen(mag_spec)

    expected_shape = mag_spec.shape
    assert energy.shape == expected_shape
Example 20
def common_step(self, batch, batch_nb, train=True):
    mix, clean = batch
    mix = unsqueeze_to_3d(mix)
    clean = unsqueeze_to_3d(clean)

    mix_tf = self.model.forward_encoder(mix)
    clean_tf = self.model.forward_encoder(clean)

    clean_pow = torch.pow(mag(clean_tf), 2)

    # HACK: train directly on the clean data to avoid changing the datasets.
    est_pow, mu, logvar = self.model.forward_vae_mu_logvar(clean_pow)

    loss, rec_loss, kl_loss = self.loss_func(est_pow, clean_pow, mu, logvar)
    self.log("rec_loss", rec_loss, logger=True)
    self.log("kl_loss", kl_loss, logger=True)
    return loss
Example 21
def forward(self, x):
    if len(x.shape) == 2:
        x = x.unsqueeze(1)
    # Compute STFT
    tf_rep = self.encoder(x)
    # Estimate TF mask from STFT features: cat([re, im, mag])
    if self.is_complex:
        to_masker = magreim(tf_rep)
    else:
        to_masker = mag(tf_rep)
    # LSTM masker expects the feature dimension last (unlike 1D conv)
    est_masks = self.masker(to_masker.transpose(1, 2)).transpose(1, 2)
    # Apply TF mask
    if self.is_complex:
        masked_tf_rep = apply_real_mask(tf_rep, est_masks)
    else:
        masked_tf_rep = apply_mag_mask(tf_rep, est_masks)
    return masked_tf_rep
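`apply_mag_mask` multiplies the re/im-stacked representation by a magnitude mask; a stand-in sketch of that idea, tiling the mask over the two stacked halves (an illustration, not the library code):

import torch

def apply_mag_mask_sketch(tf_rep, est_mask):
    # tf_rep: (batch, 2 * freq, time); est_mask: (batch, freq, time).
    return tf_rep * est_mask.repeat(1, 2, 1)

tf_rep = torch.randn(2, 2 * 257, 100)
mask = torch.sigmoid(torch.randn(2, 257, 100))
masked = apply_mag_mask_sketch(tf_rep, mask)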
Example 22
def test_misi(fb_config, feed_istft, feed_angle):
    stft = Encoder(STFTFB(**fb_config))
    istft = None if not feed_istft else Decoder(STFTFB(**fb_config))
    n_src = 3
    # Create mixture
    wav = torch.randn(2, 1, 8000)
    spec = stft(wav).unsqueeze(1)
    # Create n_src masks on mixture spec and apply them
    shape = list(spec.shape)
    shape[1] *= n_src
    tf_mask = torch.sigmoid(torch.randn(*shape))
    masked_specs = spec * tf_mask
    # Separate mag and angle.
    mag = transforms.mag(masked_specs, -2)
    angles = None if not feed_angle else transforms.angle(masked_specs, -2)
    est_wavs = misi(wav, mag, stft, angles=angles, istft_dec=istft, n_iter=2)
    # We actually don't know the last dim because ISTFT(STFT()) cuts the end
    assert est_wavs.shape[:-1] == (2, n_src)
Example 23
def forward_masker(self, tf_rep):
    batch_size = tf_rep.shape[0]
    log_mag = torch.log(mag(tf_rep)).unsqueeze(1)

    if self.has_scaler:
        length = log_mag.shape[-1]
        mean = self.scaler_mean.view(-1, 1).expand(-1, length)
        std = self.scaler_std.view(-1, 1).expand(-1, length)
        log_mag -= mean
        log_mag /= std

    padded = F.pad(log_mag, (self.padding, self.padding, 0, 0),
                   mode='replicate')
    stacks = F.unfold(padded, (self.n_freq, self.padding * 2 + 1))

    new_batch = rearrange(stacks, 'n k l -> (n l) k')
    unrolled_masks = self.masker(new_batch)
    masks = rearrange(unrolled_masks, '(n l) k -> n k l', n=batch_size)
    return masks
Example 24
def forward_masker(self, tf_rep):
    """Estimates masks based on time-frequency representations.

    Args:
        tf_rep (torch.Tensor): Time-frequency representation in
            (batch, freq, seq).

    Returns:
        torch.Tensor: Estimated masks in (batch, freq, seq).
    """
    masker_input = tf_rep
    if self.input_type == "mag":
        masker_input = mag(masker_input)
    elif self.input_type == "cat":
        masker_input = magreim(masker_input)
    est_masks = self.masker(masker_input)
    if self.output_type == "mag":
        est_masks = est_masks.repeat(1, 2, 1)
    return est_masks
Example 25
def common_step(self, batch, batch_nb, train=True):
    mix, clean = batch
    mix = unsqueeze_to_3d(mix)
    clean = unsqueeze_to_3d(clean)

    mix_tf = self.model.forward_encoder(mix)
    clean_tf = self.model.forward_encoder(clean)

    model_output = self.model.forward_masker(mix_tf)

    if self.model.target == "cIRM":
        target_mask = perfect_cirm(mix_tf, clean_tf)
        loss = self.loss_func(model_output, target_mask)
    elif self.model.target == "TMS":
        loss = self.loss_func(model_output, mag(clean_tf))
    else:
        loss = self.loss_func(model_output, clean_tf)

    return loss
Example 26
def compute_scaler(self, data_iter):
    count = 0
    total_sum = torch.zeros(self.n_freq)
    total_sum_2 = torch.zeros(self.n_freq)

    for batch in tqdm(data_iter, 'Computing scaler'):
        mix, _ = batch
        mix = _unsqueeze_to_3d(mix)
        tf_rep = self.forward_encoder(mix)
        log_mag = torch.log(mag(tf_rep))

        total_sum += torch.sum(log_mag, dim=(0, 2))
        total_sum_2 += torch.sum(log_mag.pow(2), dim=(0, 2))
        count += log_mag.shape[0] * log_mag.shape[2]

    mean = total_sum / count
    variance = (total_sum_2 / count - mean.pow(2)) * (count / (count - 1))
    std = torch.sqrt(variance)

    self.scaler_mean = mean
    self.scaler_std = std
    self.has_scaler = True
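The running sums above implement the standard one-pass mean/variance with Bessel's correction; a quick self-contained check of the same arithmetic against torch's built-ins:

import torch

x = torch.randn(513, 10000)  # (n_freq, total number of frames)
count = x.shape[1]
total_sum = x.sum(dim=1)
total_sum_2 = x.pow(2).sum(dim=1)

mean = total_sum / count
variance = (total_sum_2 / count - mean.pow(2)) * (count / (count - 1))
std = torch.sqrt(variance)

assert torch.allclose(mean, x.mean(dim=1), atol=1e-3)
assert torch.allclose(std, x.std(dim=1), atol=1e-3)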
Example 27
def forward(self, x):
    if len(x.shape) == 2:
        x = x.unsqueeze(1)
    tf_rep = self.encoder(x)
    final_proj, mask_out = self.masker(mag(tf_rep))
    return final_proj, mask_out
Example 28
def forward_masker(self, tf_rep):
    output, _, _ = self.forward_vae_mu_logvar(torch.pow(mag(tf_rep), 2))
    return from_magphase(torch.sqrt(output), angle(tf_rep))
Example 29
def forward_masker(self, tf_rep):
    return self.run_MCEM(mag(tf_rep))
Example 30
def test_mag(encoder_list):
    for (enc, fb_dim) in encoder_list:
        tf_rep = enc(torch.randn(2, 1, 16000))  # [batch, freq, time]
        batch, freq, time = tf_rep.shape
        mag = transforms.mag(tf_rep, dim=1)
        assert mag.shape == (batch, freq // 2, time)