예제 #1
0
    def forward(self, est_mask):
        """Turn the estimated masks into waveforms by pooling and overlap-add.

        The last two axes of ``est_mask`` are swapped, the new last axis is
        average-pooled in windows of ``self.L``, and the resulting frames are
        merged by ``overlap_and_add`` with a hop of ``self.L // 2``.

        Args:
            est_mask: [M, C, N, K]
        Returns:
            est_source: [M, C, T]
        """
        # Kernel (1, L) leaves the K axis alone and averages along N only;
        # with no explicit stride the pooling windows do not overlap.
        pool = nn.AvgPool2d((1, self.L))
        hop = self.L // 2

        frames = pool(est_mask.transpose(2, 3))  # [M, C, K, N] -> pooled on N
        return overlap_and_add(frames, hop)  # M x C x T
 def forward(self, mixture_w, est_mask):
     """Decode the masked mixture representation into per-source signals.

     Args:
         mixture_w: [M, N, K]
         est_mask: [M, C, N, K]
     Returns:
         est_source: [M, C, T]
     """
     # D = W * M: insert a source axis into the mixture so it broadcasts
     # against the C masks, then move the N axis last for basis_signals.
     masked = (mixture_w.unsqueeze(1) * est_mask).transpose(2, 3)  # [M, C, K, N]
     # S = DV: map every frame through the basis signals.
     frames = self.basis_signals(masked)  # [M, C, K, L]
     # Merge the frames using a hop of L // 2.
     return overlap_and_add(frames, self.L // 2)  # M x C x T
예제 #3
0
 def forward(self, mixture_w, est_mask):
     """Decode the masked mixture representation into per-source signals.

     Args:
         mixture_w: [M, K, N]
         est_mask: [M, K, C, N]
     Returns:
         est_source: [M, C, T]
     """
     # D = W * M: add a source axis at position 2 so the mixture broadcasts
     # against the C masks.
     masked = mixture_w.unsqueeze(2) * est_mask  # M x K x C x N
     # S = DV: map every frame through the basis signals.
     decoded = self.basis_signals(masked)  # M x K x C x L
     # Swap the K and C axes so frames are grouped per source, and make the
     # result contiguous before handing it to overlap_and_add.
     per_source = decoded.permute(0, 2, 1, 3).contiguous()  # M x C x K x L
     hop = self.L // 2
     return overlap_and_add(per_source, hop)  # M x C x T
예제 #4
0
 def forward(self, mixture_w, est_mask):
     """Decode the masked mixture representation into per-source signals.

     Args:
         mixture_w: [B, E, L]
         est_mask: [B, C, E, L]
     Returns:
         est_source: [B, C, T]
     """
     # D = W * M: insert a source axis into the mixture so it broadcasts
     # against the C masks, then move the E axis last for basis_signals.
     masked = (mixture_w.unsqueeze(1) * est_mask).transpose(2, 3)  # [B, C, L, E]
     # S = DV: map every frame through the basis signals.
     frames = self.basis_signals(masked)  # [B, C, L, W]
     # Merge the frames using a hop of W // 2.
     return overlap_and_add(frames, self.W // 2)  # B x C x T