def forward(self, ref, inf) -> torch.Tensor:
    """time-frequency absolute coherence loss.

    Reference:
        Independent Vector Analysis with Deep Neural Network Source Priors;
        Li et al., 2020; https://arxiv.org/abs/2008.11273

    Args:
        ref: (Batch, T, F) or (Batch, T, C, F)
        inf: (Batch, T, F) or (Batch, T, C, F)
    Returns:
        loss: (Batch,)
    """
    assert ref.shape == inf.shape, (ref.shape, inf.shape)

    if is_complex(ref) and is_complex(inf):
        # sqrt( E[|inf|^2] * E[|ref|^2] )
        denom = (
            complex_norm(ref, dim=1) * complex_norm(inf, dim=1) / ref.size(1) + EPS
        )
        coh = (inf * ref.conj()).mean(dim=1).abs() / denom
        if ref.dim() == 3:
            coh_loss = 1.0 - coh.mean(dim=1)
        elif ref.dim() == 4:
            coh_loss = 1.0 - coh.mean(dim=[1, 2])
        else:
            raise ValueError(
                "Invalid input shape: ref={}, inf={}".format(ref.shape, inf.shape)
            )
    else:
        raise ValueError("`ref` and `inf` must be complex tensors.")
    return coh_loss

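# A minimal standalone sketch of the coherence term above (assumes native complex
# tensors from torch>=1.8; the helper name and eps value are illustrative, not part
# of the class API). It computes 1 - |E_t[inf * conj(ref)]| / sqrt(E_t|ref|^2 * E_t|inf|^2)
# averaged over frequency, which is the quantity the forward() above returns per batch item.
def _coherence_loss_sketch(ref, inf, eps=1.0e-8):
    # ref, inf: (Batch, T, F) complex spectra
    num = (inf * ref.conj()).mean(dim=1).abs()  # (Batch, F)
    denom = (ref.abs().pow(2).mean(dim=1) * inf.abs().pow(2).mean(dim=1)).sqrt() + eps
    return 1.0 - (num / denom).mean(dim=1)  # (Batch,), values in [0, 1]

# Example (independent random spectra have near-zero coherence, so the loss is close to 1):
#   _coherence_loss_sketch(torch.randn(2, 100, 257, dtype=torch.cfloat),
#                          torch.randn(2, 100, 257, dtype=torch.cfloat))
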
def inverse(
    self, input: Union[torch.Tensor, ComplexTensor], ilens: torch.Tensor = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Inverse STFT.

    Args:
        input: Tensor(batch, T, F, 2) or ComplexTensor(batch, T, F)
        ilens: (batch,)
    Returns:
        wavs: (batch, samples)
        ilens: (batch,)
    """
    if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
        istft = torch.functional.istft
    else:
        try:
            import torchaudio
        except ImportError:
            raise ImportError(
                "Please install torchaudio>=0.3.0 or use torch>=1.6.0"
            )

        if not hasattr(torchaudio.functional, "istft"):
            raise ImportError(
                "Please install torchaudio>=0.3.0 or use torch>=1.6.0"
            )
        istft = torchaudio.functional.istft

    if self.window is not None:
        window_func = getattr(torch, f"{self.window}_window")
        if is_complex(input):
            datatype = input.real.dtype
        else:
            datatype = input.dtype
        window = window_func(self.win_length, dtype=datatype, device=input.device)
    else:
        window = None

    if is_complex(input):
        input = torch.stack([input.real, input.imag], dim=-1)
    elif input.shape[-1] != 2:
        raise TypeError("Invalid input type")
    input = input.transpose(1, 2)

    wavs = istft(
        input,
        n_fft=self.n_fft,
        hop_length=self.hop_length,
        win_length=self.win_length,
        window=window,
        center=self.center,
        normalized=self.normalized,
        onesided=self.onesided,
        length=ilens.max() if ilens is not None else ilens,
    )

    return wavs, ilens

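# A minimal round-trip sketch of the same iSTFT behaviour on torch>=1.8, where
# torch.stft/torch.istft work on native complex tensors directly (hypothetical helper,
# illustrative parameter values; not this class's defaults).
def _stft_istft_roundtrip_sketch(wav, n_fft=512, hop_length=128):
    # wav: (batch, samples)
    window = torch.hann_window(n_fft, device=wav.device)
    spec = torch.stft(
        wav,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=n_fft,
        window=window,
        center=True,
        return_complex=True,
    )  # (batch, F, T), complex
    return torch.istft(
        spec,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=n_fft,
        window=window,
        center=True,
        length=wav.size(-1),  # same role as `ilens.max()` above
    )  # (batch, samples)

# Example:
#   wav = torch.randn(2, 16000)
#   assert torch.allclose(_stft_istft_roundtrip_sketch(wav), wav, atol=1e-5)
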
def forward(self, ref, inf) -> torch.Tensor:
    """time-frequency MSE loss.

    Args:
        ref: (Batch, T, F) or (Batch, T, C, F)
        inf: (Batch, T, F) or (Batch, T, C, F)
    Returns:
        loss: (Batch,)
    """
    assert ref.shape == inf.shape, (ref.shape, inf.shape)

    diff = ref - inf
    if is_complex(diff):
        mseloss = diff.real**2 + diff.imag**2
    else:
        mseloss = diff**2
    if ref.dim() == 3:
        mseloss = mseloss.mean(dim=[1, 2])
    elif ref.dim() == 4:
        mseloss = mseloss.mean(dim=[1, 2, 3])
    else:
        raise ValueError(
            "Invalid input shape: ref={}, inf={}".format(ref.shape, inf.shape)
        )
    return mseloss

def tf_l1_loss(ref, inf):
    """time-frequency L1 loss.

    Args:
        ref: (Batch, T, F) or (Batch, T, C, F)
        inf: (Batch, T, F) or (Batch, T, C, F)
    Returns:
        loss: (Batch,)
    """
    assert ref.shape == inf.shape, (ref.shape, inf.shape)
    if not is_torch_1_3_plus:
        # in case of binary masks
        ref = ref.type(inf.dtype)
    if is_complex(inf):
        l1loss = abs(ref - inf + EPS)
    else:
        l1loss = abs(ref - inf)
    if ref.dim() == 3:
        l1loss = l1loss.mean(dim=[1, 2])
    elif ref.dim() == 4:
        l1loss = l1loss.mean(dim=[1, 2, 3])
    else:
        raise ValueError(
            "Invalid input shape: ref={}, inf={}".format(ref.shape, inf.shape)
        )
    return l1loss

def tf_log_mse_loss(ref, inf):
    """time-frequency log-MSE loss.

    Args:
        ref: (Batch, T, F) or (Batch, T, C, F)
        inf: (Batch, T, F) or (Batch, T, C, F)
    Returns:
        loss: (Batch,)
    """
    assert ref.shape == inf.shape, (ref.shape, inf.shape)
    if not is_torch_1_3_plus:
        # in case of binary masks
        ref = ref.type(inf.dtype)
    diff = ref - inf
    if is_complex(diff):
        log_mse_loss = diff.real**2 + diff.imag**2
    else:
        log_mse_loss = diff**2
    if ref.dim() == 3:
        log_mse_loss = torch.log10(log_mse_loss.sum(dim=[1, 2])) * 10
    elif ref.dim() == 4:
        log_mse_loss = torch.log10(log_mse_loss.sum(dim=[1, 2, 3])) * 10
    else:
        raise ValueError(
            "Invalid input shape: ref={}, inf={}".format(ref.shape, inf.shape)
        )
    return log_mse_loss

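# Worked example for the scaling above (illustrative numbers, not from the codebase):
# with inf = ref + c everywhere on a (1, 4, 4) real spectrum, sum|ref - inf|^2 = 16 * c^2, so
#   c = 0.1  -> loss = 10 * log10(0.16)   ~= -7.96
#   c = 0.01 -> loss = 10 * log10(0.0016) ~= -27.96
# i.e. every 10x reduction in the amplitude error lowers the loss by 20 (dB-style units).
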
def forward(self, ref, inf) -> torch.Tensor:
    """time-frequency L1 loss.

    Args:
        ref: (Batch, T, F) or (Batch, T, C, F)
        inf: (Batch, T, F) or (Batch, T, C, F)
    Returns:
        loss: (Batch,)
    """
    assert ref.shape == inf.shape, (ref.shape, inf.shape)

    if is_complex(inf):
        l1loss = (
            abs(ref.real - inf.real)
            + abs(ref.imag - inf.imag)
            + abs(ref.abs() - inf.abs())
        )
    else:
        l1loss = abs(ref - inf)
    if ref.dim() == 3:
        l1loss = l1loss.mean(dim=[1, 2])
    elif ref.dim() == 4:
        l1loss = l1loss.mean(dim=[1, 2, 3])
    else:
        raise ValueError(
            "Invalid input shape: ref={}, inf={}".format(ref.shape, inf.shape)
        )
    return l1loss

def forward(
    self,
    input: Union[torch.Tensor, ComplexTensor],
    ilens: torch.Tensor,
    additional: Optional[Dict] = None,
) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]:
    """DC-CRN Separator Forward.

    Args:
        input (torch.Tensor or ComplexTensor): Encoded feature
            [Batch, T, F] or [Batch, T, C, F]
        ilens (torch.Tensor): input lengths [Batch,]

    Returns:
        masked (List[Union(torch.Tensor, ComplexTensor)]): [(Batch, T, F), ...]
        ilens (torch.Tensor): (B,)
        others predicted data, e.g. masks: OrderedDict[
            'mask_spk1': torch.Tensor(Batch, Frames, Freq),
            'mask_spk2': torch.Tensor(Batch, Frames, Freq),
            ...
            'mask_spkn': torch.Tensor(Batch, Frames, Freq),
        ]
    """
    assert is_complex(input)
    is_multichannel = input.ndim == 4
    if is_multichannel:
        feature = torch.cat([input.real, input.imag], dim=2).permute(0, 2, 1, 3)
    else:
        feature = torch.stack([input.real, input.imag], dim=1)

    masks = self.dc_crn(feature)
    masks = [new_complex_like(input, m.unbind(dim=1)) for m in masks.unbind(dim=2)]
    if self.predict_noise:
        *masks, mask_noise = masks

    if self.mode == "masking":
        if is_multichannel:
            masked = [input * m.unsqueeze(2) for m in masks]
        else:
            masked = [input * m for m in masks]
    else:
        masked = masks
        if is_multichannel:
            masks = [m.unsqueeze(2) / (input + EPS) for m in masked]
        else:
            masks = [m / (input + EPS) for m in masked]

    others = OrderedDict(
        zip(["mask_spk{}".format(i + 1) for i in range(len(masks))], masks)
    )
    if self.predict_noise:
        mask_noise = mask_noise.unsqueeze(2) if is_multichannel else mask_noise
        if self.mode == "masking":
            others["noise1"] = input * mask_noise
        else:
            others["noise1"] = mask_noise
    return masked, ilens, others

def forward(
    self, xs: Union[torch.Tensor, ComplexTensor], ilens: torch.LongTensor
) -> Tuple[Tuple[torch.Tensor, ...], torch.LongTensor]:
    """Mask estimator forward function.

    Args:
        xs: (B, F, C, T)
        ilens: (B,)
    Returns:
        masks: A tuple of the masks. (B, F, C, T)
        ilens: (B,)
    """
    assert xs.size(0) == ilens.size(0), (xs.size(0), ilens.size(0))
    _, _, C, input_length = xs.size()
    # (B, F, C, T) -> (B, C, T, F)
    xs = xs.permute(0, 2, 3, 1)

    # Calculate amplitude: (B, C, T, F) -> (B, C, T, F)
    if is_complex(xs):
        xs = (xs.real**2 + xs.imag**2) ** 0.5
    # xs: (B, C, T, F) -> xs: (B * C, T, F)
    xs = xs.contiguous().view(-1, xs.size(-2), xs.size(-1))
    # ilens: (B,) -> ilens_: (B * C,)
    ilens_ = ilens[:, None].expand(-1, C).contiguous().view(-1)

    # xs: (B * C, T, F) -> xs: (B * C, T, D)
    xs, _, _ = self.brnn(xs, ilens_)
    # xs: (B * C, T, D) -> xs: (B, C, T, D)
    xs = xs.view(-1, C, xs.size(-2), xs.size(-1))

    masks = []
    for linear in self.linears:
        # xs: (B, C, T, D) -> mask: (B, C, T, F)
        mask = linear(xs)

        if self.nonlinear == "sigmoid":
            mask = torch.sigmoid(mask)
        elif self.nonlinear == "relu":
            mask = torch.relu(mask)
        elif self.nonlinear == "tanh":
            mask = torch.tanh(mask)
        elif self.nonlinear == "crelu":
            mask = torch.clamp(mask, min=0, max=1)

        # Zero padding (masked_fill is not in-place, so keep the result)
        mask = mask.masked_fill(make_pad_mask(ilens, mask, length_dim=2), 0)

        # (B, C, T, F) -> (B, F, C, T)
        mask = mask.permute(0, 3, 1, 2)

        # Takes care of multi-GPU cases: pad if input_length > max(ilens)
        if mask.size(-1) < input_length:
            mask = F.pad(mask, [0, input_length - mask.size(-1)], value=0)
        masks.append(mask)

    return tuple(masks), ilens

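# Shape-only sketch of the channel flattening used above (hypothetical helper,
# illustrative sizes): every channel is processed by the same BLSTM by folding
# the channel dimension C into the batch dimension.
def _flatten_channels_sketch(xs):
    # xs: (B, F, C, T) magnitude spectrum -> (B * C, T, F) for the shared BLSTM
    B, F_, C, T = xs.shape
    xs = xs.permute(0, 2, 3, 1)  # (B, C, T, F)
    return xs.contiguous().view(B * C, T, F_)

# Example: _flatten_channels_sketch(torch.rand(2, 257, 4, 100)).shape == (8, 100, 257)
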
def forward(
    self,
    input: Union[torch.Tensor, ComplexTensor],
    ilens: torch.Tensor,
    additional: Optional[Dict] = None,
) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]:
    """Forward.

    Args:
        input (torch.Tensor or ComplexTensor): Encoded feature [B, T, N]
        ilens (torch.Tensor): input lengths [Batch]
        additional (Dict or None): other data included in model
            NOTE: not used in this model

    Returns:
        masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, N), ...]
        ilens (torch.Tensor): (B,)
        others predicted data, e.g. masks: OrderedDict[
            'mask_spk1': torch.Tensor(Batch, Frames, Freq),
            'mask_spk2': torch.Tensor(Batch, Frames, Freq),
            ...
            'mask_spkn': torch.Tensor(Batch, Frames, Freq),
        ]
    """
    # if complex spectrum,
    if is_complex(input):
        feature = abs(input)
    else:
        feature = input

    # prepare pad_mask for transformer
    pad_mask = make_non_pad_mask(ilens).unsqueeze(1).to(feature.device)

    x, ilens = self.conformer(feature, pad_mask)

    masks = []
    for linear in self.linear:
        y = linear(x)
        y = self.nonlinear(y)
        masks.append(y)

    if self.predict_noise:
        *masks, mask_noise = masks
    masked = [input * m for m in masks]

    others = OrderedDict(
        zip(["mask_spk{}".format(i + 1) for i in range(len(masks))], masks)
    )
    if self.predict_noise:
        others["noise1"] = input * mask_noise

    return masked, ilens, others

def forward(
    self,
    input: Union[torch.Tensor, ComplexTensor],
    ilens: torch.Tensor,
    additional: Optional[Dict] = None,
) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]:
    """Forward.

    Args:
        input (torch.Tensor or ComplexTensor): Encoded feature [B, T, N]
        ilens (torch.Tensor): input lengths [Batch]
        additional (Dict or None): other data included in model
            NOTE: not used in this model

    Returns:
        masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, N), ...]
        ilens (torch.Tensor): (B,)
        others predicted data, e.g. masks: OrderedDict[
            'mask_spk1': torch.Tensor(Batch, Frames, Freq),
            'mask_spk2': torch.Tensor(Batch, Frames, Freq),
            ...
            'mask_spkn': torch.Tensor(Batch, Frames, Freq),
        ]
    """
    # if complex spectrum,
    if is_complex(input):
        feature = abs(input)
    else:
        feature = input

    B, T, N = feature.shape

    feature = feature.transpose(1, 2)  # B, N, T
    segmented, rest = split_feature(
        feature, segment_size=self.segment_size
    )  # B, N, L, K

    processed = self.dprnn(segmented)  # B, N*num_spk, L, K

    processed = merge_feature(processed, rest)  # B, N*num_spk, T

    processed = processed.transpose(1, 2)  # B, T, N*num_spk
    processed = processed.view(B, T, N, self.num_spk)
    masks = self.nonlinear(processed).unbind(dim=3)

    masked = [input * m for m in masks]

    others = OrderedDict(
        zip(["mask_spk{}".format(i + 1) for i in range(len(masks))], masks)
    )

    return masked, ilens, others

def forward(
    self,
    input: Union[torch.Tensor, ComplexTensor],
    ilens: torch.Tensor,
    additional: Optional[Dict] = None,
) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]:
    """Forward.

    Args:
        input (torch.Tensor or ComplexTensor): Encoded feature [B, T, N]
        ilens (torch.Tensor): input lengths [Batch]
        additional (Dict or None): other data included in model
            NOTE: not used in this model

    Returns:
        masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, N), ...]
        ilens (torch.Tensor): (B,)
        others predicted data, e.g. masks: OrderedDict[
            'mask_spk1': torch.Tensor(Batch, Frames, Freq),
            'mask_spk2': torch.Tensor(Batch, Frames, Freq),
            ...
            'mask_spkn': torch.Tensor(Batch, Frames, Freq),
        ]
    """
    # if complex spectrum
    if is_complex(input):
        feature = abs(input)
    else:
        feature = input

    B, L, N = feature.shape

    feature = feature.transpose(1, 2)  # B, N, L

    masks = self.tcn(feature)  # B, num_spk, N, L
    masks = masks.transpose(2, 3)  # B, num_spk, L, N
    if self.predict_noise:
        *masks, mask_noise = masks.unbind(dim=1)  # List[B, L, N]
    else:
        masks = masks.unbind(dim=1)  # List[B, L, N]

    masked = [input * m for m in masks]

    others = OrderedDict(
        zip(["mask_spk{}".format(i + 1) for i in range(len(masks))], masks)
    )
    if self.predict_noise:
        others["noise1"] = input * mask_noise

    return masked, ilens, others

def forward(
    self,
    input: Union[torch.Tensor, ComplexTensor],
    ilens: torch.Tensor,
    additional: Optional[Dict] = None,
) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]:
    """Forward.

    Args:
        input (torch.Tensor or ComplexTensor): Encoded feature [B, T, N]
        ilens (torch.Tensor): input lengths [Batch]
        additional (Dict or None): other data included in model
            NOTE: not used in this model

    Returns:
        masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, N), ...]
        ilens (torch.Tensor): (B,)
        others predicted data, e.g. masks: OrderedDict[
            'mask_spk1': torch.Tensor(Batch, Frames, Freq),
            'mask_spk2': torch.Tensor(Batch, Frames, Freq),
            ...
            'mask_spkn': torch.Tensor(Batch, Frames, Freq),
        ]
    """
    # if complex spectrum,
    if is_complex(input):
        feature = abs(input)
    else:
        feature = input

    x, ilens, _ = self.rnn(feature, ilens)

    masks = []
    for linear in self.linear:
        y = linear(x)
        y = self.nonlinear(y)
        masks.append(y)

    masked = [input * m for m in masks]

    others = OrderedDict(
        zip(["mask_spk{}".format(i + 1) for i in range(len(masks))], masks)
    )

    return masked, ilens, others

def test_dc_crn_separator_forward_backward_complex(
    input_dim,
    num_spk,
    input_channels,
    enc_hid_channels,
    enc_layers,
    glstm_groups,
    glstm_layers,
    glstm_bidirectional,
    glstm_rearrange,
    mode,
):
    model = DC_CRNSeparator(
        input_dim=input_dim,
        num_spk=num_spk,
        input_channels=input_channels,
        enc_hid_channels=enc_hid_channels,
        enc_kernel_size=(1, 3),
        enc_padding=(0, 1),
        enc_last_kernel_size=(1, 3),
        enc_last_stride=(1, 2),
        enc_last_padding=(0, 1),
        enc_layers=enc_layers,
        skip_last_kernel_size=(1, 3),
        skip_last_stride=(1, 1),
        skip_last_padding=(0, 1),
        glstm_groups=glstm_groups,
        glstm_layers=glstm_layers,
        glstm_bidirectional=glstm_bidirectional,
        glstm_rearrange=glstm_rearrange,
        mode=mode,
    )
    model.train()

    real = torch.rand(2, 10, input_dim)
    imag = torch.rand(2, 10, input_dim)
    x = torch.complex(real, imag) if is_torch_1_9_plus else ComplexTensor(real, imag)
    x_lens = torch.tensor([10, 8], dtype=torch.long)

    masked, flens, others = model(x, ilens=x_lens)

    assert is_complex(masked[0])
    assert len(masked) == num_spk

    masked[0].abs().mean().backward()

def test_dc_crn_separator_multich_input(
    num_spk,
    input_channels,
    enc_kernel_size,
    enc_padding,
    enc_last_kernel_size,
    enc_last_stride,
    enc_last_padding,
    skip_last_kernel_size,
    skip_last_stride,
    skip_last_padding,
):
    model = DC_CRNSeparator(
        input_dim=33,
        num_spk=num_spk,
        input_channels=input_channels,
        enc_hid_channels=2,
        enc_kernel_size=enc_kernel_size,
        enc_padding=enc_padding,
        enc_last_kernel_size=enc_last_kernel_size,
        enc_last_stride=enc_last_stride,
        enc_last_padding=enc_last_padding,
        enc_layers=3,
        skip_last_kernel_size=skip_last_kernel_size,
        skip_last_stride=skip_last_stride,
        skip_last_padding=skip_last_padding,
    )
    model.train()

    real = torch.rand(2, 10, input_channels[0] // 2, 33)
    imag = torch.rand(2, 10, input_channels[0] // 2, 33)
    x = torch.complex(real, imag) if is_torch_1_9_plus else ComplexTensor(real, imag)
    x_lens = torch.tensor([10, 8], dtype=torch.long)

    masked, flens, others = model(x, ilens=x_lens)

    assert is_complex(masked[0])
    assert len(masked) == num_spk

    masked[0].abs().mean().backward()

def forward(
    self, input: Union[torch.Tensor, ComplexTensor], ilens: torch.Tensor
) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]:
    """Forward.

    Args:
        input (torch.Tensor or ComplexTensor): Encoded feature [B, T, N]
        ilens (torch.Tensor): input lengths [Batch]

    Returns:
        masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, N), ...]
        ilens (torch.Tensor): (B,)
        others predicted data, e.g. masks: OrderedDict[
            'mask_spk1': torch.Tensor(Batch, Frames, Freq),
            'mask_spk2': torch.Tensor(Batch, Frames, Freq),
            ...
            'mask_spkn': torch.Tensor(Batch, Frames, Freq),
        ]
    """
    # if complex spectrum
    if is_complex(input):
        feature = abs(input)
    else:
        feature = input

    B, L, N = feature.shape

    feature = feature.transpose(1, 2)  # B, N, L

    masks = self.tcn(feature)  # B, num_spk, N, L
    masks = masks.transpose(2, 3)  # B, num_spk, L, N
    masks = masks.unbind(dim=1)  # List[B, L, N]

    masked = [input * m for m in masks]

    others = OrderedDict(
        zip(["mask_spk{}".format(i + 1) for i in range(len(masks))], masks)
    )

    return masked, ilens, others

def forward(
    self, input: Union[torch.Tensor, ComplexTensor], ilens: torch.Tensor
) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]:
    """Forward.

    Args:
        input (torch.Tensor or ComplexTensor): Encoded feature [B, T, N]
        ilens (torch.Tensor): input lengths [Batch]

    Returns:
        masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, N), ...]
        ilens (torch.Tensor): (B,)
        others predicted data, e.g. masks: OrderedDict[
            'mask_spk1': torch.Tensor(Batch, Frames, Freq),
            'mask_spk2': torch.Tensor(Batch, Frames, Freq),
            ...
            'mask_spkn': torch.Tensor(Batch, Frames, Freq),
        ]
    """
    # if complex spectrum,
    if is_complex(input):
        feature = abs(input)
    else:
        feature = input

    B, T, N = feature.shape

    processed = self.skim(feature)  # B, T, N

    processed = processed.view(B, T, N, self.num_spk)
    masks = self.nonlinear(processed).unbind(dim=3)

    masked = [input * m for m in masks]

    others = OrderedDict(
        zip(["mask_spk{}".format(i + 1) for i in range(len(masks))], masks)
    )

    return masked, ilens, others

def signal_framing(
    signal: Union[torch.Tensor, ComplexTensor],
    frame_length: int,
    frame_step: int,
    bdelay: int,
    do_padding: bool = False,
    pad_value: int = 0,
    indices: List = None,
) -> Union[torch.Tensor, ComplexTensor]:
    """Expand `signal` into several frames, with each frame of length `frame_length`.

    Args:
        signal: (..., T)
        frame_length: length of each segment
        frame_step: step for selecting frames
        bdelay: delay for WPD
        do_padding: whether or not to pad the input signal at the beginning
            of the time dimension
        pad_value: value to fill in the padding

    Returns:
        torch.Tensor:
            if do_padding: (..., T, frame_length)
            else: (..., T - bdelay - frame_length + 2, frame_length)
    """
    if isinstance(signal, ComplexTensor):
        complex_wrapper = ComplexTensor
        pad_func = FC.pad
    elif is_torch_complex_tensor(signal):
        complex_wrapper = torch.complex
        pad_func = torch.nn.functional.pad
    else:
        pad_func = torch.nn.functional.pad

    frame_length2 = frame_length - 1
    # pad to the right at the last dimension of `signal` (time dimension)
    if do_padding:
        # (..., T) --> (..., T + bdelay + frame_length - 2)
        signal = pad_func(signal, (bdelay + frame_length2 - 1, 0), "constant", pad_value)
        do_padding = False

    if indices is None:
        # [[ 0, 1, ..., frame_length2 - 1, frame_length2 - 1 + bdelay ],
        #  [ 1, 2, ..., frame_length2,     frame_length2 + bdelay     ],
        #  [ 2, 3, ..., frame_length2 + 1, frame_length2 + 1 + bdelay ],
        #  ...
        #  [ T - bdelay - frame_length2, ..., T - 1 - bdelay, T - 1 ]]
        indices = [
            [*range(i, i + frame_length2), i + frame_length2 + bdelay - 1]
            for i in range(0, signal.shape[-1] - frame_length2 - bdelay + 1, frame_step)
        ]

    if is_complex(signal):
        real = signal_framing(
            signal.real,
            frame_length,
            frame_step,
            bdelay,
            do_padding,
            pad_value,
            indices,
        )
        imag = signal_framing(
            signal.imag,
            frame_length,
            frame_step,
            bdelay,
            do_padding,
            pad_value,
            indices,
        )
        return complex_wrapper(real, imag)
    else:
        # (..., T - bdelay - frame_length + 2, frame_length)
        signal = signal[..., indices]
        return signal

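# Worked example of the index pattern above (illustrative values): with
# frame_length=3 (so frame_length2=2), bdelay=2, frame_step=1 and T=8, the
# generated index rows are
#   [0, 1, 3], [1, 2, 4], [2, 3, 5], [3, 4, 6], [4, 5, 7]
# i.e. T - bdelay - frame_length + 2 = 5 frames, each holding frame_length - 1
# consecutive samples plus one sample delayed by bdelay (the WPD prediction delay).
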
def forward(
    self,
    input: Union[torch.Tensor, ComplexTensor],
    ilens: torch.Tensor,
    additional: Optional[Dict] = None,
) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]:
    """Forward.

    Args:
        input (torch.Tensor or ComplexTensor): Encoded feature [B, T, N]
        ilens (torch.Tensor): input lengths [Batch]
        additional (Dict or None): other data included in model
            NOTE: not used in this model

    Returns:
        masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, N), ...]
        ilens (torch.Tensor): (B,)
        others predicted data, e.g. masks: OrderedDict[
            'mask_spk1': torch.Tensor(Batch, Frames, Freq),
            'mask_spk2': torch.Tensor(Batch, Frames, Freq),
            ...
            'mask_spkn': torch.Tensor(Batch, Frames, Freq),
        ]
    """
    # if complex spectrum,
    if is_complex(input):
        feature = abs(input)
    elif self.post_enc_relu:
        feature = torch.nn.functional.relu(input)
    else:
        feature = input

    B, T, N = feature.shape

    feature = feature.transpose(1, 2)  # B, N, T
    feature = self.enc_LN(feature)
    segmented = self.split_feature(feature)  # B, N, L, K

    processed = self.dptnet(segmented)  # B, N*num_spk, L, K
    processed = processed.reshape(
        B * self.num_spk, -1, processed.size(-2), processed.size(-1)
    )  # B*num_spk, N, L, K

    processed = self.merge_feature(processed, length=T)  # B*num_spk, N, T

    # gated output layer for filter generation (B*num_spk, N, T)
    processed = self.output(processed) * self.output_gate(processed)

    masks = processed.reshape(B, self.num_spk, N, T)
    # list[(B, T, N)]
    masks = self.nonlinear(masks.transpose(-1, -2)).unbind(dim=1)

    masked = [input * m for m in masks]

    others = OrderedDict(
        zip(["mask_spk{}".format(i + 1) for i in range(len(masks))], masks)
    )

    return masked, ilens, others