Esempio n. 1
0
 def compute_spectral_loss(self, encoder, est_target, target, EPS=1e-8):
     """Spectral loss between an estimated and a target signal.

     Both signals are passed through ``encoder`` and ``take_mag`` and then
     flattened per batch item. The result is
     ``self.norm1(est - ref) + self.alpha * self.norm1(log(est + EPS) - log(ref + EPS))``.

     Args:
         encoder: callable applied to both signals before ``take_mag``.
         est_target: estimated signal; dim 0 is the batch axis.
         target: reference signal, same layout as ``est_target``.
         EPS (float): offset added inside the logarithms.
     """
     n_batch = est_target.shape[0]
     est_mag = take_mag(encoder(est_target)).view(n_batch, -1)
     ref_mag = take_mag(encoder(target)).view(n_batch, -1)
     linear_term = self.norm1(est_mag - ref_mag)
     # EPS keeps the log finite on zero-magnitude bins.
     log_term = self.norm1(torch.log(est_mag + EPS) - torch.log(ref_mag + EPS))
     return linear_term + self.alpha * log_term
Esempio n. 2
0
def transform(mixture, sources):
    """Return the mixture magnitude and per-source ratio masks.

    Args:
        mixture: mixture representation passed to ``take_mag``.
        sources: iterable of source representations passed to ``take_mag``.

    Returns:
        tuple: ``(take_mag(mixture) + EPS, masks)``. ``masks`` stacks, along
        dim 1, each source magnitude divided by (sum of all source
        magnitudes + EPS).
    """
    mix_mag = take_mag(mixture) + EPS
    src_mags = [take_mag(src) for src in sources]
    # EPS guards against 0/0 on bins where every source is silent.
    denominator = torch.stack(src_mags, 0).sum(0) + EPS
    masks = [mag / denominator for mag in src_mags]
    return mix_mag, torch.stack(masks, 1)
Esempio n. 3
0
def test_pmsqe_pit(n_src, sample_rate):
    """Check that SingleSrcPMSQE runs inside PITLossWrapper ('pw_pt' mode)."""
    # Supported STFT: 512-point at 16 kHz, 256-point otherwise, hop = half.
    win = 512 if sample_rate == 16000 else 256
    stft = Encoder(STFTFB(kernel_size=win, n_filters=win, stride=win // 2))
    # Random multi-source references and estimates, batch of 2.
    ref = torch.randn(2, n_src, 16000)
    est = torch.randn(2, n_src, 16000)
    ref_spec = transforms.take_mag(stft(ref))
    est_spec = transforms.take_mag(stft(est))
    pit_pmsqe = PITLossWrapper(SingleSrcPMSQE(sample_rate=sample_rate),
                               pit_from='pw_pt')
    # The forward pass itself is the check: it must not raise.
    loss_value = pit_pmsqe(est_spec, ref_spec)
Esempio n. 4
0
 def unpack_data(self, batch):
     """Split a ``(mix, sources, noise)`` batch and build mask targets.

     Only index 0 of the last axis (first channel) of each signal is kept.

     Returns:
         tuple: ``(mix, binary_mask, real_mask)``. ``real_mask`` is each source
         magnitude divided by (noise magnitude + sum of source magnitudes
         over dim 1 + EPS). ``binary_mask`` is ``real_mask.argmax(1)``.
     """
     mix, sources, noise = batch
     # Keep only the first channel of every signal.
     mix, sources, noise = mix[..., 0], sources[..., 0], noise[..., 0]
     encoder = self.model.encoder
     src_mag_spec = take_mag(encoder(sources))
     # The noise gets an axis at dim 1 before encoding and another after,
     # presumably to match the per-source layout of src_mag_spec -- confirm.
     noise_mag_spec = take_mag(encoder(noise.unsqueeze(1))).unsqueeze(1)
     denominator = noise_mag_spec + src_mag_spec.sum(1, keepdim=True) + EPS
     real_mask = src_mag_spec / denominator
     # Index of the source with the largest mask value in each bin.
     binary_mask = real_mask.argmax(1)
     return mix, binary_mask, real_mask
Esempio n. 5
0
 def dc_head_separate(self, x):
     """Cluster embeddings to produce binary masks, output waveforms.

     Embeddings of the active TF bins (per ``ebased_vad``) are clustered with
     KMeans into ``self.masker.n_src`` groups. Each group yields one binary
     mask. Bins that are inactive are set to 1 in every mask.

     NOTE(review): ``active_bins.view(1, -1)`` indexing appears to assume a
     batch size of 1 -- confirm before calling with larger batches.

     Args:
         x: input waveform. A 2-D input gets an axis inserted at dim 1.

     Returns:
         tuple: ``(wavs, dic_out)``. ``wavs`` is the decoded masked
         representation padded to ``x``. ``dic_out`` has keys ``tfrep``,
         ``mask``, ``masked_tfrep`` and ``proj``.
     """
     kmeans = KMeans(n_clusters=self.masker.n_src)
     if len(x.shape) == 2:
         x = x.unsqueeze(1)
     tf_rep = self.encoder(x)
     mag_spec = take_mag(tf_rep)
     proj, mask_out = self.masker(mag_spec)
     active_bins = ebased_vad(mag_spec)
     active_proj = proj[active_bins.view(1, -1)]
     # KMeans works on NumPy arrays: detach from the graph and move to CPU.
     bin_clusters = kmeans.fit_predict(active_proj.detach().cpu().numpy())
     # Create binary masks
     est_mask_list = []
     for i in range(self.masker.n_src):
         # Inactive bins start at True (ones); active bins get membership of
         # cluster i.
         mask = ~active_bins
         mask[active_bins] = torch.from_numpy(
             (bin_clusters == i)).to(mask.device)
         est_mask_list.append(mask.float())  # Need float, not bool
     # Go back to time domain
     est_masks = torch.stack(est_mask_list, dim=1)
     # Masks are stacked on dim 1, so give tf_rep a matching axis (as
     # ``separate`` does). Otherwise tf_rep's batch axis would be aligned with
     # the masks' source axis. This is identical to the old code for batch 1.
     masked = apply_mag_mask(tf_rep.unsqueeze(1), est_masks)
     wavs = pad_x_to_y(self.decoder(masked), x)
     dic_out = dict(tfrep=tf_rep,
                    mask=mask_out,
                    masked_tfrep=masked,
                    proj=proj)
     return wavs, dic_out
Esempio n. 6
0
def compute_cost(model, batch):
    """Compute the deep-clustering loss and the PIT mask loss for one batch.

    Args:
        model: network called on the mixture magnitude spectrogram. Its
            output is indexed as ``[0]`` (embedding) and ``[1]`` (masks).
        batch: raw batch, split by ``unpack_data`` into
            ``(inputs, targets, masks)``.

    Returns:
        tuple: ``(dc_loss, mask_loss)``, where ``mask_loss`` is the value of
        ``pit_loss(est_masks, masks)``.
    """
    # Follow the model's device instead of hard-coding ``.cuda()``. This is
    # identical on the default GPU and also works on CPU or other devices.
    device = next(model.parameters()).device
    inputs, targets, masks = unpack_data(batch)
    spec = take_mag(enc(inputs.unsqueeze(1))).to(device)
    est_targets = model(spec)
    mask_loss = pit_loss(est_targets[1], masks.to(device))
    embedding = est_targets[0]
    # Restrict the deep-clustering loss to bins flagged by compute_vad.
    vad_tf_mask = compute_vad(spec).to(device)
    dc_loss = deep_clustering_loss(embedding, targets.to(device),
                                   binary_mask=vad_tf_mask)
    return dc_loss, mask_loss
Esempio n. 7
0
 def unpack_data(self, batch):
     """Split a ``(mix, sources)`` batch and derive ratio and binary masks.

     Returns:
         tuple: ``(mix, binary_mask, real_mask)``. ``real_mask`` is each source
         magnitude divided by (sum over dim 1 + EPS). ``binary_mask`` is the
         argmax of ``real_mask`` over dim 1.
     """
     mix, sources = batch
     src_mag_spec = take_mag(self.model.encoder(sources))
     # Normalise by the total source magnitude; EPS avoids 0/0 on silent bins.
     total = src_mag_spec.sum(1, keepdim=True) + EPS
     real_mask = src_mag_spec / total
     return mix, real_mask.argmax(1), real_mask
Esempio n. 8
0
def test_pmsqe(sample_rate):
    """SingleSrcPMSQE yields one value per batch item, whatever the axis order."""
    # Supported STFT: 512-point at 16 kHz, 256-point otherwise, hop = half.
    win = 512 if sample_rate == 16000 else 256
    stft = Encoder(STFTFB(kernel_size=win, n_filters=win, stride=win // 2))
    ref = torch.randn(2, 1, 16000)
    est = torch.randn(2, 1, 16000)
    ref_spec = transforms.take_mag(stft(ref))
    est_spec = transforms.take_mag(stft(est))
    pmsqe = SingleSrcPMSQE(sample_rate=sample_rate)
    loss_value = pmsqe(est_spec, ref_spec)
    # One loss value per batch item.
    assert loss_value.shape[0] == ref.shape[0]
    # Swapping axes 1 and 2 of both inputs must give the same result.
    tr_loss_value = pmsqe(est_spec.transpose(1, 2), ref_spec.transpose(1, 2))
    assert_allclose(loss_value, tr_loss_value)
Esempio n. 9
0
 def common_step(self, batch, batch_nb, train=False):
     """Run the model on one batch and return ``(loss, loss_dic)``.

     When ``self.mask_mixture`` is set, estimated and target masks are both
     multiplied by the mixture magnitude (broadcast over dim 1) before the
     loss. ``batch_nb`` and ``train`` are not used in this body.
     """
     inputs, targets, masks = self.unpack_data(batch)
     embeddings, est_masks = self(inputs)
     spec = take_mag(self.model.encoder(inputs.unsqueeze(1)))
     if self.mask_mixture:
         mix_spec = spec.unsqueeze(1)
         est_masks, masks = est_masks * mix_spec, masks * mix_spec
     loss, loss_dic = self.loss_func(embeddings, targets,
                                     est_src=est_masks,
                                     target_src=masks,
                                     mix_spec=spec)
     return loss, loss_dic
Esempio n. 10
0
def test_griffinlim(fb_config, feed_istft, feed_angle):
    """Smoke-test griffin_lim with an optional ISTFT decoder and initial angles."""
    stft = Encoder(STFTFB(**fb_config))
    istft = Decoder(STFTFB(**fb_config)) if feed_istft else None
    wav = torch.randn(2, 1, 8000)
    spec = stft(wav)
    # Apply a random mask with values in (0, 1).
    masked_spec = spec * torch.sigmoid(torch.randn_like(spec))
    mag = transforms.take_mag(masked_spec, -2)
    angles = transforms.angle(masked_spec, -2) if feed_angle else None
    griffin_lim(mag, stft, angles=angles, istft_dec=istft, n_iter=3)
Esempio n. 11
0
 def separate(self, x):
     """Separate with mask-inference head, output waveforms.

     Returns:
         tuple: ``(wavs, dic_out)``. ``wavs`` is the decoded masked
         representation padded to ``x``. ``dic_out`` has keys ``tfrep``,
         ``mask``, ``masked_tfrep`` and ``proj``.
     """
     if x.ndim == 2:
         x = x.unsqueeze(1)
     tf_rep = self.encoder(x)
     proj, mask_out = self.masker(take_mag(tf_rep))
     # Add an axis at dim 1 so tf_rep broadcasts against mask_out (presumably
     # one mask per source along dim 1).
     masked = apply_mag_mask(tf_rep.unsqueeze(1), mask_out)
     wavs = torch_utils.pad_x_to_y(self.decoder(masked), x)
     return wavs, {"tfrep": tf_rep, "mask": mask_out,
                   "masked_tfrep": masked, "proj": proj}
Esempio n. 12
0
def distance(estimate, target, is_complex=True):
    """Compute the mean squared distance between two spectrograms.

    With ``is_complex=True`` the distance is measured in the complex plane,
    which makes more sense when the network computes a complex mask.

    Args:
        estimate (torch.Tensor): Estimate complex spectrogram.
        target (torch.Tensor): Speech target complex spectrogram.
        is_complex (bool): Whether to compute the distance in the complex or
            the magnitude space.

    Returns:
        torch.Tensor: the loss value, a 0-dim (scalar) tensor produced by
        ``.mean()``.
    """
    if is_complex:
        # Squared magnitude of the complex difference, averaged over all bins.
        return take_mag(estimate - target).pow(2).mean()
    # Squared difference between magnitudes, averaged over all bins.
    return (take_mag(estimate) - take_mag(target)).pow(2).mean()
Esempio n. 13
0
 def unpack_data(self, batch):
     """Split a ``(mix, sources)`` batch and derive ratio and binary masks.

     ``sources`` is unpacked as ``(n_batch, n_src, n_sample)``. Each source is
     encoded separately with ``self.enc``.

     Returns:
         tuple: ``(mix, binary_mask, real_mask)``, where ``real_mask`` is each
         source magnitude divided by (sum over sources + EPS).
     """
     mix, sources = batch
     n_batch, n_src, n_sample = sources.shape
     # Fold sources into the batch axis and add a singleton axis at dim 1.
     flat_sources = sources.view(-1, n_sample).unsqueeze(1)
     flat_mag = take_mag(self.enc(flat_sources))
     # Unfold back to (n_batch, n_src, fft_dim, -1).
     src_mag_spec = flat_mag.view(n_batch, n_src, flat_mag.shape[1], -1)
     real_mask = src_mag_spec / (src_mag_spec.sum(1, keepdim=True) + EPS)
     # Index of the source with the maximum energy in each bin.
     return mix, real_mask.argmax(1), real_mask
Esempio n. 14
0
def test_angle_mag_recompostion(dim):
    """Test complex --> (mag, angle) --> complex conversions."""
    n_dims = 4
    # Random 4-D shape; the complex axis must have even length.
    shape = [random.randint(1, 10) for _ in range(n_dims)]
    shape[dim] *= 2
    original = torch.randn(shape)
    phase = transforms.angle(original, dim=dim)
    mag = transforms.take_mag(original, dim=dim)
    rebuilt = transforms.from_mag_and_phase(mag, phase, dim=dim)
    assert_allclose(original, rebuilt)
Esempio n. 15
0
    def forward(self, x, z):
        """
        Forward pass of discriminator.

        Both signals are encoded to magnitude spectrograms and concatenated
        as two channels along dim 1. The result goes through ``self.conv``,
        ``self.pool`` and ``self.linear``.

        Args:
            x: inputs
            z: clean

        Returns:
            Output of ``self.linear``. The batch axis is kept even when the
            batch size is 1.
        """
        # Encode inputs
        x = take_mag(self.encoder(x)).unsqueeze(1)
        # Encode clean reference
        z = take_mag(self.encoder(z)).unsqueeze(1)

        x = torch.cat((x, z), dim=1)
        x = self.pool(self.conv(x))
        # Drop singleton non-batch dims only. A bare ``squeeze()`` also
        # removes the batch dim when the batch size is 1. For batch > 1 the
        # shape is identical to ``squeeze()``.
        x = x.reshape(x.shape[0], *[d for d in x.shape[1:] if d != 1])
        return self.linear(x)
Esempio n. 16
0
 def forward(self, x):
     """Estimate a TF mask from the encoded input and apply it.

     With ``self.is_complex`` the masker sees ``take_cat`` features and the
     mask is applied with ``apply_real_mask``. Otherwise it sees the
     magnitude and the mask is applied with ``apply_mag_mask``.
     A 2-D input gets an axis inserted at dim 1.
     """
     if x.ndim == 2:
         x = x.unsqueeze(1)
     tf_rep = self.encoder(x)
     if self.is_complex:
         features, apply_mask = take_cat(tf_rep), apply_real_mask
     else:
         features, apply_mask = take_mag(tf_rep), apply_mag_mask
     # The masker expects features last, so swap axes 1 and 2 around it.
     est_masks = self.masker(features.transpose(1, 2)).transpose(1, 2)
     return apply_mask(tf_rep, est_masks)
Esempio n. 17
0
def test_misi(fb_config, feed_istft, feed_angle):
    """MISI returns one estimate per source for a randomly masked mixture."""
    stft = Encoder(STFTFB(**fb_config))
    istft = Decoder(STFTFB(**fb_config)) if feed_istft else None
    n_src = 3
    wav = torch.randn(2, 1, 8000)
    spec = stft(wav).unsqueeze(1)
    # One random (0, 1) mask per source: repeat dim 1 n_src times.
    mask_shape = list(spec.shape)
    mask_shape[1] *= n_src
    masked_specs = spec * torch.sigmoid(torch.randn(*mask_shape))
    mag = transforms.take_mag(masked_specs, -2)
    angles = transforms.angle(masked_specs, -2) if feed_angle else None
    est_wavs = misi(wav, mag, stft, angles=angles, istft_dec=istft, n_iter=2)
    # ISTFT(STFT()) may cut the end, so the last dim is not checked.
    assert est_wavs.shape[:-1] == (2, n_src)
Esempio n. 18
0
    def forward(self, x):
        """
        Forward pass of generator.

        Args:
            x: input batch (signal)

        Returns:
            The decoded masked spectrogram, padded to ``x`` with
            ``torch_utils.pad_x_to_y``.
        """
        spec = self.encoder(x)
        # The LSTM consumes features last, so swap axes 1 and 2.
        features = torch.transpose(take_mag(spec), 1, 2)
        self.LSTM.flatten_parameters()
        hidden, _ = self.LSTM(features)
        mask = torch.transpose(self.model(hidden), 1, 2)
        decoded = self.decoder(apply_mag_mask(spec, mask))
        return torch_utils.pad_x_to_y(decoded, x)
Esempio n. 19
0
def test_mag(encoder_list):
    """take_mag over dim 1 halves that axis and keeps the others."""
    for enc, _ in encoder_list:
        tf_rep = enc(torch.randn(2, 1, 16000))
        n_batch, n_freq, n_frames = tf_rep.shape  # [batch, freq, time]
        mag = transforms.take_mag(tf_rep, dim=1)
        assert mag.shape == (n_batch, n_freq // 2, n_frames)
Esempio n. 20
0
 def forward(self, x):
     """Return ``self.masker`` applied to the magnitude of the encoded input.

     A 2-D input gets an axis inserted at dim 1 first.
     """
     x = x.unsqueeze(1) if x.ndim == 2 else x
     return self.masker(take_mag(self.encoder(x)))
Esempio n. 21
0
 def forward(self, x):
     """Return ``(final_proj, mask_out)`` from the masker on the encoded magnitude.

     A 2-D input gets an axis inserted at dim 1 first.
     """
     if x.ndim == 2:
         x = x.unsqueeze(1)
     mag = take_mag(self.encoder(x))
     final_proj, mask_out = self.masker(mag)
     return final_proj, mask_out
Esempio n. 22
0
 def common_step(self, batch, batch_nb, train=False):
     """Run the model on one batch and return ``self.loss_func``'s value.

     The loss also receives the mixture magnitude computed with ``self.enc``.
     ``batch_nb`` and ``train`` are not used in this body.
     """
     inputs, targets, masks = self.unpack_data(batch)
     est_targets = self(inputs)
     mix_mag = take_mag(self.enc(inputs.unsqueeze(1)))
     return self.loss_func(est_targets, targets, masks, mix_mag)