Esempio n. 1
0
    def forward(self, input):
        """Process a sequence segment by segment with the stacked blocks.

        The time axis is cut into segments of ``self.segment_size`` frames,
        either via ``split_feature`` (when ``self.seg_overlap`` is set) or
        via ``self._padfeature`` followed by a plain reshape. Each block in
        ``self.seg_lstms`` then runs over all segments; when
        ``self.mem_type`` is truthy the state returned by a block is passed
        through the matching ``self.mem_lstms`` entry before the next block
        (never after the last block).

        Args:
            input: tensor unpacked as (B, T, D).

        Returns:
            The result of ``self.output_fc`` with the time axis moved back to
            dim 1. In the non-overlapping mode the sequence is cut back to
            the original T frames before ``self.output_fc`` is applied.
        """
        batch, n_frames, feat_dim = input.shape
        overlapped = self.seg_overlap

        # Chunk the time axis so that the tensor becomes (B, S, K, D).
        if overlapped:
            chunks, rest = split_feature(
                input.transpose(1, 2), segment_size=self.segment_size
            )  # B, D, K, S
            chunks = chunks.permute(0, 3, 2, 1).contiguous()
        else:
            chunks, rest = self._padfeature(input=input)
            chunks = chunks.view(batch, -1, self.segment_size, feat_dim)
        batch, n_seg, seg_len, feat_dim = chunks.shape

        assert seg_len == self.segment_size

        # Fold the segments into the batch axis: (B*S, K, D).
        hidden = chunks.view(batch * n_seg, seg_len, feat_dim).contiguous()
        state = None
        for idx in range(self.num_blocks):
            hidden, state = self.seg_lstms[idx](hidden, state)
            # The state is only refreshed between blocks, not after the last.
            if self.mem_type and idx < self.num_blocks - 1:
                state = self.mem_lstms[idx](state, n_seg)

        if overlapped:
            unfolded = hidden.view(batch, n_seg, seg_len, feat_dim)
            unfolded = unfolded.permute(0, 3, 2, 1)  # B, D, K, S
            merged = merge_feature(unfolded, rest)  # B, D, T
            return self.output_fc(merged).transpose(1, 2)

        # Drop the frames that were only added as padding.
        trimmed = hidden.view(batch, n_seg * seg_len, feat_dim)[:, :n_frames, :]
        return self.output_fc(trimmed.transpose(1, 2)).transpose(1, 2)
Esempio n. 2
0
    def forward(
        self,
        input: Union[torch.Tensor, ComplexTensor],
        ilens: torch.Tensor,
        additional: Optional[Dict] = None,
    ) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor,
               OrderedDict]:
        """Estimate one mask per speaker and apply each mask to the input.

        Args:
            input (torch.Tensor or ComplexTensor): feature of shape (B, T, N).
                For a complex input the masks are estimated from ``abs(input)``
                but are applied to the original ``input``.
            ilens (torch.Tensor): input lengths, returned unchanged.
            additional (Dict or None): other data included in model.
                NOTE: not used in this method.

        Returns:
            masked (List[Union(torch.Tensor, ComplexTensor)]): the input
                multiplied by each mask, one entry per speaker,
                [(B, T, N), ...].
            ilens (torch.Tensor): the ``ilens`` argument as passed in.
            others (OrderedDict): the masks keyed ``'mask_spk1'``,
                ``'mask_spk2'``, ..., each of shape (B, T, N).
        """
        # Work on the magnitude when the spectrum is complex.
        feature = abs(input) if is_complex(input) else input

        batch, frames, freq = feature.shape

        # (B, N, T) is cut into segments: B, N, L, K
        segmented, rest = split_feature(
            feature.transpose(1, 2), segment_size=self.segment_size
        )

        # B, N*num_spk, L, K -> B, N*num_spk, T
        merged = merge_feature(self.dprnn(segmented), rest)

        # B, T, N*num_spk -> B, T, N, num_spk; the speaker index varies fastest.
        stacked = merged.transpose(1, 2).view(batch, frames, freq, self.num_spk)
        masks = self.nonlinear(stacked).unbind(dim=3)

        masked = [input * mask for mask in masks]

        others = OrderedDict(
            ("mask_spk{}".format(spk), mask)
            for spk, mask in enumerate(masks, start=1)
        )

        return masked, ilens, others
Esempio n. 3
0
    def forward(
        self,
        input: torch.Tensor,
        ilens: torch.Tensor,
        additional: Optional[Dict] = None,
    ) -> Tuple[List[torch.Tensor], torch.Tensor, OrderedDict]:
        """Encode the mixture, separate it, and decode per-speaker estimates.

        Args:
            input (torch.Tensor): mixture fed to ``self.encoder``. Its last
                dimension is taken as the time length that every decoded
                estimate is padded to (presumably a waveform batch of shape
                (B, T); confirm against the encoder).
            ilens (torch.Tensor): input lengths, returned unchanged.
            additional (Dict or None): other data included in model.
                NOTE: not used in this method.

        Returns:
            outputs: in training mode, a list with one entry per element of
                the separator output, each entry being the list of
                per-speaker tensors for that block (this is a list of lists,
                not the flat list the return annotation suggests). In eval
                mode, only the per-speaker list of the last block.
            ilens (torch.Tensor): the ``ilens`` argument as passed in.
            others (dict): always empty; this method predicts no masks.
        """
        # The decoder may not restore the exact input length, so remember it
        # and pad every estimate back to it.
        T_mix = input.size(-1)

        mixture_w = self.encoder(input)

        enc_segments, enc_rest = split_feature(mixture_w, self.segment_size)
        # separate
        output_all = self.rnn_model(enc_segments)

        # Every block is decoded in training so the loss can supervise each
        # of them. At inference only the last block is returned, so decoding
        # the intermediate ones would be wasted work.
        blocks = output_all if self.training else output_all[-1:]

        decoded = []
        for block_out in blocks:
            est = merge_feature(block_out, enc_rest)
            est = est.view(
                input.shape[0], self._num_spk, self.enc_dim, mixture_w.shape[2]
            )
            est = self.decoder(est)
            # A negative amount (estimate longer than the mixture) trims.
            est = F.pad(est, (0, T_mix - est.size(-1)))
            decoded.append(list(est.unbind(dim=1)))

        if self.training:
            outputs = decoded
        else:
            outputs = decoded[-1] if decoded else []

        others = {}
        return outputs, ilens, others
Esempio n. 4
0
    def forward(self, input, num_mic):
        """Estimate beamforming filters from multi-channel features.

        Args:
            input: tensor unpacked as (B, ch, N, T).
            num_mic: passed through unchanged to ``self.dprnn_model``.

        Returns:
            The filter tensor. For ``fasnet_type == "fasnet"`` it is viewed
            as (B, ch, num_spk, L, output_dim); for ``"ifasnet"`` it is
            viewed as (B, ch, num_spk, output_dim, L). L is the remaining
            length inferred by ``view``.

        Raises:
            ValueError: if ``self.fasnet_type`` is neither ``"fasnet"`` nor
                ``"ifasnet"``.
        """
        # Fail fast: without this check an unsupported type would run the
        # whole model and only then die with UnboundLocalError on bf_filter.
        if self.fasnet_type not in ("fasnet", "ifasnet"):
            raise ValueError(
                "fasnet_type must be 'fasnet' or 'ifasnet', got {!r}".format(
                    self.fasnet_type
                )
            )

        # input: (B, ch, N, T)
        batch_size, ch, N, seq_length = input.shape

        input = input.view(batch_size * ch, N, seq_length)  # B*ch, N, T
        enc_feature = self.BN(input)

        # split the encoder output into overlapped, longer segments
        enc_segments, enc_rest = dprnn.split_feature(
            enc_feature, self.segment_size
        )  # B*ch, N, L, K

        # restore the channel axis for the model
        enc_segments = enc_segments.view(
            batch_size, ch, -1, enc_segments.shape[2], enc_segments.shape[3]
        )  # B, ch, N, L, K
        output = self.dprnn_model(enc_segments, num_mic).view(
            batch_size * ch * self.num_spk,
            self.feature_dim,
            self.segment_size,
            -1,
        )  # B*ch*nspk, N, L, K
        # overlap-and-add of the outputs
        output = dprnn.merge_feature(output, enc_rest)  # B*ch*nspk, N, T

        if self.fasnet_type == "fasnet":
            # gated output layer for filter generation
            bf_filter = self.output(output) * self.output_gate(
                output
            )  # B*ch*nspk, K, T
            bf_filter = (
                bf_filter.transpose(1, 2)
                .contiguous()
                .view(batch_size, ch, self.num_spk, -1, self.output_dim)
            )  # B, ch, nspk, L, N
        else:
            # "ifasnet" (guaranteed by the check above): plain output layer
            bf_filter = self.output(output)  # B*ch*nspk, K, T
            bf_filter = bf_filter.view(
                batch_size, ch, self.num_spk, self.output_dim, -1
            )  # B, ch, nspk, K, L

        return bf_filter