Example #1
    def forward(self, est_embeddings, target_indices, est_src=None, target_src=None, mix_spec=None):
        """

        Args:
            est_embeddings (torch.Tensor): Estimated embedding from the DC head.
            target_indices (torch.Tensor): Target indices that'll be passed to
                the DC loss.
            est_src (torch.Tensor): Estimated magnitude spectrograms (or masks).
            target_src (torch.Tensor): Target magnitude spectrograms (or masks).
            mix_spec (torch.Tensor): The magnitude spectrogram of the mixture
                from which VAD will be computed. If None, no VAD is used.

        Returns:
            torch.Tensor, the total loss, averaged over the batch.
            dict with `dc_loss` and `pit_loss` keys, unweighted losses.
        """
        if self.alpha != 0 and (est_src is None or target_src is None):
            raise ValueError(
                "Expected target and estimated spectrograms to compute the PIT loss, found None."
            )
        binary_mask = None
        if mix_spec is not None:
            binary_mask = ebased_vad(mix_spec)
        # The DC loss is already normalized by the VAD mask inside the loss function.
        dc_loss = deep_clustering_loss(
            embedding=est_embeddings, tgt_index=target_indices, binary_mask=binary_mask
        )
        src_pit_loss = self.src_mse(est_src, target_src)
        # Equation (4) from the Chimera paper.
        tot = self.alpha * dc_loss.mean() + (1 - self.alpha) * src_pit_loss
        # Return unweighted losses as well for logging.
        loss_dict = dict(dc_loss=dc_loss.mean(), pit_loss=src_pit_loss)
        return tot, loss_dict
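All of these snippets delegate to `deep_clustering_loss`. For reference, the objective behind it can be written as ||V V^T - Y Y^T||_F^2, where V holds the embeddings and Y the one-hot speaker assignments. Below is a minimal sketch of that objective (no VAD masking and no normalization, so not the library's exact implementation); shapes follow the tests further down.

import torch

def dc_loss_sketch(embedding, tgt_index):
    # Simplified sketch of the deep clustering objective:
    #   || V V^T - Y Y^T ||_F^2
    #     = ||V^T V||_F^2 - 2 * ||V^T Y||_F^2 + ||Y^T Y||_F^2,
    # expanded so only small (D x D), (D x S) and (S x S) products appear.
    batch = embedding.shape[0]
    n_spk = int(tgt_index.max()) + 1
    # One-hot speaker assignments Y, flattened over time-frequency bins.
    y = torch.nn.functional.one_hot(tgt_index.reshape(batch, -1), n_spk).float()
    v = embedding
    vtv = v.transpose(1, 2) @ v  # (batch, D, D)
    vty = v.transpose(1, 2) @ y  # (batch, D, S)
    yty = y.transpose(1, 2) @ y  # (batch, S, S)
    # One loss value per utterance, as in the tests below.
    return vtv.pow(2).sum((1, 2)) - 2 * vty.pow(2).sum((1, 2)) + yty.pow(2).sum((1, 2))

print(dc_loss_sketch(torch.randn(10, 5 * 400, 20),
                     torch.randint(2, (10, 400, 5))).shape)  # torch.Size([10])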
Example #2
def compute_cost(model, batch):
    inputs, targets, masks = unpack_data(batch)
    # Magnitude spectrogram of the mixture (`enc` is a global STFT-like encoder).
    spec = take_mag(enc(inputs.unsqueeze(1))).cuda()
    # The model returns a tuple: (embeddings, estimated masks).
    est_targets = model(spec)
    masks = masks.cuda()
    # Permutation-invariant loss between estimated and true masks.
    loss = pit_loss(est_targets[1], masks)
    embedding = est_targets[0]

    # Energy-based VAD mask keeps silent TF bins out of the DC loss.
    vad_tf_mask = compute_vad(spec).cuda()
    targets = targets.cuda()
    dc_loss = deep_clustering_loss(embedding, targets, binary_mask=vad_tf_mask)
    return dc_loss, loss
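`compute_vad` is not defined in this snippet. A minimal energy-based sketch in the spirit of the `ebased_vad` call from Example #1 could look like the following; the function body and the 40 dB threshold are assumptions, not the snippet's actual helper.

import torch

def compute_vad(mag_spec, threshold_db=40.0):
    # Hypothetical energy-based VAD: a TF bin counts as "active" if it
    # lies within `threshold_db` of the loudest bin of its utterance.
    log_mag = 20 * torch.log10(mag_spec + 1e-8)
    # Per-utterance maximum over all time-frequency bins.
    max_log = log_mag.amax(dim=(-2, -1), keepdim=True)
    return log_mag > (max_log - threshold_db)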
Example #3
def handle_multiple_loss(est_heads, targets, true_real_masks, inputs, alpha=1):
    """ Handles deep clustering loss and the PIT loss
    Args:
        est_heads (tuple): Tuple containing embedding and estimated masks
        targets (np.array): Binary masks with one hot encoding
        true_real_masks (np.array): True real valued masks
        inputs(np.array): Spectrogram of the mixture
        alpha(int): Weight for the pit_loss

    Return:
        Sum of the deep clustering and PIT loss
    """
    embedding, est_masks = est_heads
    dc_loss = deep_clustering_loss(embedding, targets)
    pit_loss_batch = pit_loss(est_masks * inputs.unsqueeze(1),
                              true_real_masks * inputs.unsqueeze(1))
    return dc_loss.mean() + alpha * pit_loss_batch
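`pit_loss` is likewise external to these snippets. For the two-speaker case, a minimal sketch is shown below, under assumed (batch, n_src, frames, freq) shapes; the semantics (MSE over both speaker orderings, best one kept) are an assumption, not the actual helper.

import torch

def pit_loss(est_srcs, target_srcs):
    # Hypothetical two-source PIT MSE: score both speaker orderings per
    # utterance, keep the better one, and average over the batch.
    direct = ((est_srcs - target_srcs) ** 2).mean(dim=(1, 2, 3))
    # flip(1) swaps the two sources along the source axis.
    swapped = ((est_srcs - target_srcs.flip(1)) ** 2).mean(dim=(1, 2, 3))
    return torch.minimum(direct, swapped).mean()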
Example #4
def test_dc(spk_cnt):
    embedding = torch.randn(10, 5 * 400, 20)
    targets = torch.LongTensor(10, 400, 5).random_(0, spk_cnt)
    loss = deep_clustering_loss(embedding, targets, spk_cnt)
    assert loss.shape[0] == 10
Example #5
def test_dc(spk_cnt):
    embedding = torch.randn(10, 5 * 400, 20)
    targets = torch.zeros(10, 400, 5).random_(0, spk_cnt).long()
    loss = deep_clustering_loss(embedding, targets)
    assert loss.shape[0] == 10
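Both tests receive `spk_cnt` as an argument, which reads like a pytest parametrization. A hedged guess at the missing boilerplate follows; the parameter values and the import path are assumptions.

import pytest
import torch
from asteroid.losses.cluster import deep_clustering_loss  # import path assumed

@pytest.mark.parametrize("spk_cnt", [2, 3])  # values assumed
def test_dc(spk_cnt):
    embedding = torch.randn(10, 5 * 400, 20)
    targets = torch.zeros(10, 400, 5).random_(0, spk_cnt).long()
    loss = deep_clustering_loss(embedding, targets)
    # The loss is returned per utterance, hence one value per batch item.
    assert loss.shape[0] == 10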