def __init__(
    self, kernel_size, stride, in_chan, n_src, bn_chan, chunk_size, hop_size, mask_act
):
    super(SingleDecoder, self).__init__()
    self.kernel_size = kernel_size
    self.stride = stride
    self.in_chan = in_chan
    self.bn_chan = bn_chan
    self.chunk_size = chunk_size
    self.hop_size = hop_size
    self.n_src = n_src
    self.mask_act = mask_act

    # Masking in 3D space
    net_out_conv = nn.Conv2d(bn_chan, n_src * bn_chan, 1)
    self.first_out = nn.Sequential(nn.PReLU(), net_out_conv)
    # Gating and masking in 2D space (after fold)
    self.net_out = nn.Sequential(nn.Conv1d(bn_chan, bn_chan, 1), nn.Tanh())
    self.net_gate = nn.Sequential(nn.Conv1d(bn_chan, bn_chan, 1), nn.Sigmoid())
    self.mask_net = nn.Conv1d(bn_chan, in_chan, 1, bias=False)

    # Get activation function.
    mask_nl_class = activations.get(mask_act)
    # For softmax, feed the source dimension.
    if has_arg(mask_nl_class, "dim"):
        self.output_act = mask_nl_class(dim=1)
    else:
        self.output_act = mask_nl_class()

    _, self.trans_conv = make_enc_dec(
        "free", kernel_size=kernel_size, stride=stride, n_filters=in_chan
    )
def __init__(
    self,
    architecture,
    stft_n_filters=1024,
    stft_kernel_size=1024,
    stft_stride=256,
    sample_rate=16000.0,
    **masknet_kwargs,
):
    self.architecture = architecture
    self.stft_n_filters = stft_n_filters
    self.stft_kernel_size = stft_kernel_size
    self.stft_stride = stft_stride
    self.masknet_kwargs = masknet_kwargs

    encoder, decoder = make_enc_dec(
        "stft",
        n_filters=stft_n_filters,
        kernel_size=stft_kernel_size,
        stride=stft_stride,
        sample_rate=sample_rate,
    )
    masker = self.masknet_class.default_architecture(architecture, **masknet_kwargs)
    super().__init__(encoder, masker, decoder)
def __init__(
    self,
    activation="tanh",
    latent_dim=64,
    hidden_dim_encoder=128,
    n_filters=1024,
    kernel_size=1024,
    stride=256,
    padding=3,
    sample_rate=8000,
):
    encoder, decoder = make_enc_dec(
        "stft",
        n_filters=n_filters,
        kernel_size=kernel_size,
        stride=stride,
        sample_rate=sample_rate,
    )
    self.activation = activation
    self.latent_dim = latent_dim
    self.hidden_dim_encoder = hidden_dim_encoder
    self.input_dim = n_filters // 2 + 1
    self.padding = padding
    masker = VAE_inner(
        self.input_dim, self.latent_dim, self.hidden_dim_encoder, self.padding, self.activation
    )
    super().__init__(encoder, masker, decoder)
    self.register_buffer("scaler_mean", torch.zeros(self.input_dim))
    self.register_buffer("scaler_std", torch.zeros(self.input_dim))
    self.has_scaler = False
def make_model_and_optimizer(conf):
    """Function to define the model and optimizer for a config dictionary.

    Args:
        conf: Dictionary containing the output of hierarchical argparse.

    Returns:
        model, optimizer.

    The main goal of this function is to make reloading for resuming
    and evaluation very simple.
    """
    # Define building blocks for local model.
    # The encoder and decoder can directly be made from the dictionary.
    encoder, decoder = fb.make_enc_dec(**conf["filterbank"])
    # The input post-processing changes the dimensions of input features to
    # the mask network. Different types of masks impose different output
    # dimensions on the mask network's output. We correct for these here.
    nn_in = int(encoder.n_feats_out * encoder.in_chan_mul)
    nn_out = int(encoder.n_feats_out * encoder.out_chan_mul)
    masker = TDConvNet(in_chan=nn_in, out_chan=nn_out, **conf["masknet"])
    # Another possibility is to correct for these effects inside of Model,
    # but then instantiation of the masker should also be done inside.
    model = Model(encoder, masker, decoder)
    # The model is defined in Container, which is passed to DataParallel.

    # Define optimizer: can be instantiated from the dictionary as well.
    optimizer = make_optimizer(model.parameters(), **conf["optim"])
    return model, optimizer
def make_model_and_optimizer(conf):
    """Function to define the model and optimizer for a config dictionary.

    Args:
        conf: Dictionary containing the output of hierarchical argparse.

    Returns:
        model, optimizer.

    The main goal of this function is to make reloading for resuming
    and evaluation very simple.
    """
    # Define building blocks for local model
    stft, istft = make_enc_dec("stft", **conf["filterbank"])
    # Because we concatenate (re, im, mag) as input and compute a complex mask.
    if conf["main_args"]["is_complex"]:
        inp_size = int(stft.n_feats_out * 3 / 2)
        output_size = stft.n_feats_out
    else:
        inp_size = output_size = int(stft.n_feats_out / 2)
    # Add these fields to the mask model dict
    conf["masknet"].update(dict(input_size=inp_size, output_size=output_size))
    masker = SimpleModel(**conf["masknet"])
    # Make the complete model
    model = Model(stft, masker, istft, is_complex=conf["main_args"]["is_complex"])
    # Define optimizer of this model
    optimizer = make_optimizer(model.parameters(), **conf["optim"])
    return model, optimizer
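# A minimal sketch of the config structure the function above expects; the
# concrete values and any extra masknet keys come from the recipe's YAML files
# and are illustrative assumptions only:
example_conf = {
    "filterbank": {"n_filters": 512, "kernel_size": 400, "stride": 100},
    "masknet": {},  # input_size / output_size are filled in by the function
    "main_args": {"is_complex": True},
    "optim": {"optimizer": "adam", "lr": 1e-3},
}
# model, optimizer = make_model_and_optimizer(example_conf)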
def __init__(
    self,
    n_src,
    out_chan=None,
    bn_chan=128,
    hid_size=128,
    chunk_size=100,
    hop_size=None,
    n_repeats=6,
    norm_type="gLN",
    mask_act="sigmoid",
    bidirectional=True,
    rnn_type="LSTM",
    num_layers=1,
    dropout=0,
    in_chan=None,
    fb_name="free",
    kernel_size=16,
    n_filters=64,
    stride=8,
    encoder_activation=None,
    sample_rate=8000,
    **fb_kwargs,
):
    encoder, decoder = make_enc_dec(
        fb_name,
        kernel_size=kernel_size,
        n_filters=n_filters,
        stride=stride,
        sample_rate=sample_rate,
        **fb_kwargs,
    )
    n_feats = encoder.n_feats_out
    if in_chan is not None:
        assert in_chan == n_feats, (
            "Number of filterbank output channels and number of input channels "
            f"should be the same. Received {n_feats} and {in_chan}"
        )
    # Update in_chan
    masker = DPRNN(
        n_feats,
        n_src,
        out_chan=out_chan,
        bn_chan=bn_chan,
        hid_size=hid_size,
        chunk_size=chunk_size,
        hop_size=hop_size,
        n_repeats=n_repeats,
        norm_type=norm_type,
        mask_act=mask_act,
        bidirectional=bidirectional,
        rnn_type=rnn_type,
        num_layers=num_layers,
        dropout=dropout,
    )
    super().__init__(encoder, masker, decoder, encoder_activation=encoder_activation)
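# A minimal usage sketch, assuming the constructor above is asteroid's
# DPRNNTasNet (a BaseEncoderMaskerDecoder subclass); shapes are illustrative:
import torch
from asteroid.models import DPRNNTasNet

model = DPRNNTasNet(n_src=2, sample_rate=8000)
mixture = torch.randn(1, 8000)  # (batch, time): one second at 8 kHz
est_sources = model(mixture)    # -> (batch, n_src, time)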
def test_perfect_istft_default_parameters(fb_config):
    """Unit test perfect reconstruction with default values."""
    kernel_size = fb_config["kernel_size"]
    enc, dec = make_enc_dec("stft", **fb_config)
    inp_wav = torch.randn(2, 1, 32000)
    out_wav = dec(enc(inp_wav))[:, :, kernel_size:-kernel_size]
    inp_test = inp_wav[:, :, kernel_size:-kernel_size]
    testing.assert_allclose(inp_test, out_wav)
def __init__(
    self,
    n_srcs,
    bn_chan=128,
    hid_size=128,
    chunk_size=100,
    hop_size=None,
    n_repeats=6,
    norm_type="gLN",
    mask_act="sigmoid",
    bidirectional=True,
    rnn_type="LSTM",
    num_layers=1,
    dropout=0,
    kernel_size=16,
    n_filters=64,
    stride=8,
    encoder_activation=None,
    use_mulcat=False,
    sample_rate=8000,
):
    super().__init__(sample_rate=sample_rate)
    self.encoder_activation = encoder_activation
    self.enc_activation = activations.get(encoder_activation or "linear")()
    hop_size = hop_size if hop_size is not None else chunk_size // 2
    self.encoder, _ = make_enc_dec(
        "free",
        kernel_size=kernel_size,
        n_filters=n_filters,
        stride=stride,
    )
    # Update in_chan
    self.masker = DPRNN_MultiStage(
        in_chan=n_filters,
        bn_chan=bn_chan,
        hid_size=hid_size,
        chunk_size=chunk_size,
        hop_size=hop_size,
        n_repeats=n_repeats,
        norm_type=norm_type,
        bidirectional=bidirectional,
        rnn_type=rnn_type,
        use_mulcat=use_mulcat,
        num_layers=num_layers,
        dropout=dropout,
    )
    self.decoder_select = Decoder_Select(
        kernel_size=kernel_size,
        stride=stride,
        in_chan=n_filters,
        n_srcs=n_srcs,
        bn_chan=bn_chan,
        chunk_size=chunk_size,
        hop_size=hop_size,
        mask_act=mask_act,
    )
def __init__(self, target="TCS", inner_channels=64, dilated_layers=10, total_layers=13, max_dilation=None, n_filters=2048, kernel_size=2048, stride=1024, sample_rate=8000): encoder, decoder = make_enc_dec( "stft", n_filters=n_filters, kernel_size=kernel_size, stride=stride, sample_rate=sample_rate, ) assert target in ("TMS", "TCS", "cIRM") self.target = target self.inner_channels = inner_channels self.dilated_layers = dilated_layers self.total_layers = total_layers self.max_dilation = max_dilation input_channels = 2 if self.target in ("cIRM", "TCS") else 1 layers = [] prev_ch = input_channels num_square_layers = self.total_layers - self.dilated_layers assert num_square_layers > 0 for idx in range(self.dilated_layers): dilation = 2**idx if self.max_dilation is not None: dilation = min(dilation, self.max_dilation) layers.append( SMoLnetDilatedLayer(prev_ch, self.inner_channels, dilation=dilation)) prev_ch = self.inner_channels for idx in range(num_square_layers): layers.append( SMoLnetLateLayer(self.inner_channels, self.inner_channels)) if self.target == "TMS": layers.append(nn.Conv2d(self.inner_channels, 1, kernel_size=1)) layers.append(nn.Softplus()) else: layers.append(nn.Conv2d(self.inner_channels, 2, kernel_size=1)) masker = nn.Sequential(*layers) super().__init__(encoder, masker, decoder)
def test_stft_def(fb_config):
    """Check consistency between two calls."""
    fb = STFTFB(**fb_config)
    enc = Encoder(fb)
    dec = Decoder(fb)
    enc2, dec2 = make_enc_dec("stft", **fb_config)
    testing.assert_allclose(enc.filterbank.filters(), enc2.filterbank.filters())
    testing.assert_allclose(dec.filterbank.filters(), dec2.filterbank.filters())
def __init__(
    self,
    n_src,
    n_heads=4,
    ff_hid=256,
    chunk_size=100,
    hop_size=None,
    n_repeats=6,
    norm_type="gLN",
    ff_activation="relu",
    encoder_activation="relu",
    mask_act="relu",
    bidirectional=True,
    dropout=0,
    in_chan=None,
    fb_name="free",
    kernel_size=16,
    n_filters=64,
    stride=8,
    sample_rate=8000,
    **fb_kwargs,
):
    encoder, decoder = make_enc_dec(
        fb_name,
        kernel_size=kernel_size,
        n_filters=n_filters,
        stride=stride,
        sample_rate=sample_rate,
        **fb_kwargs,
    )
    n_feats = encoder.n_feats_out
    if in_chan is not None:
        assert in_chan == n_feats, (
            "Number of filterbank output channels and number of input channels "
            f"should be the same. Received {n_feats} and {in_chan}"
        )
    # Update in_chan
    masker = DPTransformer(
        n_feats,
        n_src,
        n_heads=n_heads,
        ff_hid=ff_hid,
        ff_activation=ff_activation,
        chunk_size=chunk_size,
        hop_size=hop_size,
        n_repeats=n_repeats,
        norm_type=norm_type,
        mask_act=mask_act,
        bidirectional=bidirectional,
        dropout=dropout,
    )
    super().__init__(encoder, masker, decoder, encoder_activation=encoder_activation)
def __init__(
    self,
    n_src,
    out_chan=None,
    n_blocks=8,
    n_repeats=3,
    bn_chan=128,
    hid_chan=512,
    skip_chan=128,
    conv_kernel_size=3,
    norm_type="gLN",
    mask_act="sigmoid",
    in_chan=None,
    causal=False,
    fb_name="free",
    kernel_size=16,
    n_filters=512,
    stride=8,
    encoder_activation=None,
    sample_rate=8000,
    **fb_kwargs,
):
    encoder, decoder = make_enc_dec(
        fb_name,
        kernel_size=kernel_size,
        n_filters=n_filters,
        stride=stride,
        sample_rate=sample_rate,
        **fb_kwargs,
    )
    n_feats = encoder.n_feats_out
    if in_chan is not None:
        assert in_chan == n_feats, (
            "Number of filterbank output channels and number of input channels "
            f"should be the same. Received {n_feats} and {in_chan}"
        )
    # Update in_chan
    masker = TDConvNet(
        n_feats,
        n_src,
        out_chan=out_chan,
        n_blocks=n_blocks,
        n_repeats=n_repeats,
        bn_chan=bn_chan,
        hid_chan=hid_chan,
        skip_chan=skip_chan,
        conv_kernel_size=conv_kernel_size,
        norm_type=norm_type,
        mask_act=mask_act,
        causal=causal,
    )
    super().__init__(encoder, masker, decoder, encoder_activation=encoder_activation)
def __init__( self, input_type="mag", output_type="mag", hidden_dims=(1024, ), dropout=0.0, activation="relu", mask_act="relu", norm_type="gLN", fb_name="stft", n_filters=512, stride=256, kernel_size=512, sample_rate=16000, **fb_kwargs, ): fb_type = fb_kwargs.pop("fb_type", None) if fb_type: warnings.warn( "Using `fb_type` keyword argument is deprecated and " "will be removed in v0.4.0. Use `fb_name` instead.", VisibleDeprecationWarning, ) fb_name = fb_type encoder, decoder = make_enc_dec( fb_name, kernel_size=kernel_size, n_filters=n_filters, stride=stride, sample_rate=sample_rate, **fb_kwargs, ) n_masker_in = self._get_n_feats_input(input_type, encoder.n_feats_out) n_masker_out = self._get_n_feats_output(output_type, encoder.n_feats_out) masker = build_demask_masker( n_masker_in, n_masker_out, norm_type=norm_type, activation=activation, hidden_dims=hidden_dims, dropout=dropout, mask_act=mask_act, ) super().__init__(encoder, masker, decoder) self.input_type = input_type self.output_type = output_type self.hidden_dims = hidden_dims self.dropout = dropout self.activation = activation self.mask_act = mask_act self.norm_type = norm_type
def __init__( self, n_src, out_chan=None, rnn_type="lstm", n_layers=4, hid_size=512, dropout=0.3, mask_act="sigmoid", bidirectional=True, in_chan=None, fb_name="free", n_filters=64, kernel_size=16, stride=8, encoder_activation=None, sample_rate=8000, **fb_kwargs, ): encoder, decoder = make_enc_dec( fb_name, kernel_size=kernel_size, n_filters=n_filters, stride=stride, sample_rate=sample_rate, **fb_kwargs, ) n_feats = encoder.n_feats_out if in_chan is not None: assert in_chan == n_feats, ("Number of filterbank output channels" " and number of input channels should " "be the same. Received " f"{n_feats} and {in_chan}") # Real gated encoder encoder = _GatedEncoder(encoder) # Masker masker = LSTMMasker( n_feats, n_src, out_chan=out_chan, hid_size=hid_size, mask_act=mask_act, bidirectional=bidirectional, rnn_type=rnn_type, n_layers=n_layers, dropout=dropout, ) super().__init__(encoder, masker, decoder, encoder_activation=encoder_activation)
def make_model_and_optimizer(conf):
    """Function to define the model and optimizer for a config dictionary.

    Args:
        conf: Dictionary containing the output of hierarchical argparse.

    Returns:
        model, optimizer.

    The main goal of this function is to make reloading for resuming
    and evaluation very simple.
    """
    enc, dec = fb.make_enc_dec("stft", **conf["filterbank"])
    masker = Chimera(enc.n_feats_out // 2, **conf["masknet"])
    model = Model(enc, masker, dec)
    optimizer = make_optimizer(model.parameters(), **conf["optim"])
    return model, optimizer
def test_jit_filterbanks_enc(filter_bank_name, inference_data):
    n_filters = 32
    if filter_bank_name == TorchSTFTFB:
        kernel_size = n_filters
    else:
        kernel_size = n_filters // 2
    enc, _ = make_enc_dec(filter_bank_name, n_filters=n_filters, kernel_size=kernel_size)

    inputs = ((torch.rand(1, 200) - 0.5) * 2,)
    traced = torch.jit.trace(enc, inputs)

    with torch.no_grad():
        res = enc(inference_data)
        out = traced(inference_data)
    print(traced.code_with_constants)
    print(traced.code)
    assert_allclose(res, out)
def __init__( self, fb_name="free", kernel_size=16, n_filters=32, stride=8, **fb_kwargs, ): super().__init__() if fb_name == TorchSTFTFB: n_filters = kernel_size encoder, decoder = make_enc_dec( fb_name, kernel_size=kernel_size, n_filters=n_filters, stride=stride, **fb_kwargs ) self.encoder = encoder self.decoder = decoder
def make_model_and_optimizer(conf):
    """Function to define the model and optimizer for a config dictionary.

    Args:
        conf: Dictionary containing the output of hierarchical argparse.

    Returns:
        model, optimizer.

    The main goal of this function is to make reloading for resuming
    and evaluation very simple.
    """
    # Define building blocks for local model
    enc, dec = fb.make_enc_dec("free", **conf["filterbank"])
    masker = DPRNN(**conf["masknet"])
    model = Model(enc, masker, dec)
    # Define optimizer of this model
    optimizer = make_optimizer(model.parameters(), **conf["optim"])
    return model, optimizer
def __init__(
    self,
    n_src,
    bn_chan=128,
    num_blocks=16,
    upsampling_depth=4,
    mask_act="relu",
    in_chan=None,
    fb_name="free",
    kernel_size=21,
    n_filters=512,
    stride=None,
    sample_rate=8000,
    **fb_kwargs,
):
    stride = kernel_size // 2 if not stride else stride
    # Need the encoder to determine the number of input channels
    enc, dec = make_enc_dec(
        fb_name,
        kernel_size=kernel_size,
        n_filters=n_filters,
        stride=stride,
        sample_rate=sample_rate,
        padding=kernel_size // 2,
        output_padding=(kernel_size // 2) - 1,
        **fb_kwargs,
    )
    n_feats = enc.n_feats_out
    enc = _Padder(enc, upsampling_depth=upsampling_depth, kernel_size=kernel_size)
    if in_chan is not None:
        assert in_chan == n_feats, (
            "Number of filterbank output channels and number of input channels "
            f"should be the same. Received {n_feats} and {in_chan}"
        )
    masker = SuDORMRFImproved(
        n_feats,
        n_src,
        bn_chan=bn_chan,
        num_blocks=num_blocks,
        upsampling_depth=upsampling_depth,
        mask_act=mask_act,
    )
    super().__init__(enc, masker, dec, encoder_activation=None)
def __init__( self, input_type="mag", output_type="mag", hidden_dims=(1024, ), dropout=0.0, activation="relu", mask_act="relu", norm_type="gLN", fb_name="stft", n_filters=512, stride=256, kernel_size=512, sample_rate=16000, **fb_kwargs, ): encoder, decoder = make_enc_dec( fb_name, kernel_size=kernel_size, n_filters=n_filters, stride=stride, sample_rate=sample_rate, **fb_kwargs, ) n_masker_in = self._get_n_feats_input(input_type, encoder.n_feats_out) n_masker_out = self._get_n_feats_output(output_type, encoder.n_feats_out) masker = build_demask_masker( n_masker_in, n_masker_out, norm_type=norm_type, activation=activation, hidden_dims=hidden_dims, dropout=dropout, mask_act=mask_act, ) super().__init__(encoder, masker, decoder) self.input_type = input_type self.output_type = output_type self.hidden_dims = hidden_dims self.dropout = dropout self.activation = activation self.mask_act = mask_act self.norm_type = norm_type
def __init__(
    self, median_kernel_size=11, n_filters=2048, kernel_size=2048, stride=512, sample_rate=8000
):
    encoder, decoder = make_enc_dec(
        "stft",
        n_filters=n_filters,
        kernel_size=kernel_size,
        stride=stride,
        sample_rate=sample_rate,
    )
    self.median_kernel_size = median_kernel_size
    self.pad = kernel_size // 2
    super().__init__(encoder, nn.Identity(), decoder)
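# A minimal sketch of how a time-axis median filter could be applied to the
# magnitude spectrogram (the model above stores `median_kernel_size` for this;
# the helper below is a hypothetical illustration, not the model's own method):
import torch

def median_filter_time(mag_spec: torch.Tensor, kernel_size: int) -> torch.Tensor:
    """mag_spec: (batch, freq, frames). Returns a tensor of the same shape."""
    pad = kernel_size // 2
    padded = torch.nn.functional.pad(mag_spec, (pad, pad), mode="reflect")
    frames = padded.unfold(-1, kernel_size, 1)  # (batch, freq, frames, kernel_size)
    return frames.median(dim=-1).values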
def __init__(self, n_filters=512, kernel_size=400, stride=100, sample_rate=8000, **masker_kwargs):
    encoder, decoder = make_enc_dec(
        "stft",
        n_filters=n_filters,
        kernel_size=kernel_size,
        stride=stride,
        sample_rate=sample_rate,
    )
    self.masker_kwargs = masker_kwargs
    masker = PhasenMasker(num_bins=(n_filters // 2 + 1), **masker_kwargs)
    super().__init__(encoder, masker, decoder)
def __init__(self, activation="relu", hidden_layers=(2048, 2048, 2048), padding=3, dropout=0.2, n_filters=256, kernel_size=256, stride=128, use_sigmoid=False, sample_rate=8000): encoder, decoder = make_enc_dec( "stft", n_filters=n_filters, kernel_size=kernel_size, stride=stride, sample_rate=sample_rate, ) self.activation = activation self.padding = padding self.hidden_layers = hidden_layers self.dropout = dropout self.n_freq = n_filters // 2 + 1 prev_dim = self.n_freq * (self.padding * 2 + 1) layers = [] for n_hid in self.hidden_layers: layers.append(nn.Linear(prev_dim, n_hid)) layers.append(activations.get(activation)()) if dropout != 0: layers.append(nn.Dropout(dropout)) prev_dim = n_hid layers.append(nn.Linear(prev_dim, self.n_freq)) if use_sigmoid: layers.append(nn.Sigmoid()) masker = nn.Sequential(*layers) super().__init__(encoder, masker, decoder) self.register_buffer('scaler_mean', torch.zeros(self.n_freq)) self.register_buffer('scaler_std', torch.zeros(self.n_freq)) self.has_scaler = False
def __init__( self, fb_name="free", kernel_size=16, n_filters=32, stride=8, encoder_activation=None, **fb_kwargs, ): encoder, decoder = make_enc_dec(fb_name, kernel_size=kernel_size, n_filters=n_filters, stride=stride, **fb_kwargs) masker = torch.nn.Identity() super().__init__(encoder, masker, decoder, encoder_activation=encoder_activation)
def __init__( self, activation="tanh", n_filters=256, kernel_size=256, stride=128, padding=3, hid_dim = 2048, z_dim = 64, sample_rate=8000 ): self.padding = padding self.n_freq = n_filters // 2 + 1 prev_dim = self.n_freq * (self.padding * 2 + 1) self.hid_dim = hid_dim self.z_dim = z_dim encoder, decoder = make_enc_dec( "stft", n_filters=n_filters, kernel_size=kernel_size, stride=stride, sample_rate=sample_rate, ) #fake masker to compy with asteroid BaseEncoderMaskerDecoder masker = nn.Sequential(nn.Identity(54, unused_argument1=0.1, unused_argument2=False)) self.activation = activation super().__init__(encoder, masker, decoder) #real masker self.enc1 = nn.Linear(prev_dim, self.hid_dim) self.enc2 = nn.Linear(hid_dim, self.hid_dim) self.enc3 = nn.Linear(hid_dim, self.n_freq) self.enc_mu_logvar = nn.Linear(self.n_freq, self.z_dim) self.dec1 = nn.Linear(self.z_dim, self.hid_dim) self.dec2 = nn.Linear(self.hid_dim, self.n_freq) self.register_buffer('scaler_mean', torch.zeros(self.n_freq)) self.register_buffer('scaler_std', torch.zeros(self.n_freq)) self.has_scaler = False
def __init__( self, target_metric="ESTOI", mask_threshold=0.05, n_filters=512, kernel_size=512, stride=256, sample_rate=8000, ): assert target_metric in ("PESQ", "STOI", "ESTOI") self.target_metric = target_metric self.mask_threshold = mask_threshold encoder, decoder = make_enc_dec( "stft", n_filters=n_filters, kernel_size=kernel_size, stride=stride, sample_rate=sample_rate, ) n_dim = n_filters // 2 + 1 generator = Generator_Sigmoid_LSTM_Masker(n_dim=n_dim) discriminator = Discriminator_Stride1_SN() BaseEncoderMaskerDecoder.__init__(self, encoder, generator, decoder) self.generator = generator self.discriminator = discriminator if self.target_metric == "STOI": self.metric_module = MetricSTOI(self.sample_rate, False) elif self.target_metric == "ESTOI": self.metric_module = MetricSTOI(self.sample_rate, True) else: self.metric_module = MetricPESQ(self.sample_rate)
def __init__(
    self,
    n_src,
    out_chan=None,
    n_blocks=8,
    n_repeats=3,
    bn_chan=128,
    hid_chan=512,
    skip_chan=128,
    conv_kernel_size=3,
    norm_type="gLN",
    mask_act="sigmoid",
    in_chan=None,
    causal=False,
    fb_name="free",
    kernel_size=16,
    n_filters=512,
    stride=8,
    encoder_activation=None,
    sample_rate=8000,
    **fb_kwargs,
):
    encoder, decoder = make_enc_dec(
        fb_name,
        kernel_size=kernel_size,
        n_filters=n_filters,
        stride=stride,
        sample_rate=sample_rate,
        **fb_kwargs,
    )
    n_feats = encoder.n_feats_out
    if in_chan is not None:
        assert in_chan == n_feats, (
            "Number of filterbank output channels and number of input channels "
            f"should be the same. Received {n_feats} and {in_chan}"
        )
    if causal and norm_type not in ["cgLN", "cLN"]:
        # Warn before overwriting norm_type so the message reports the rejected value.
        warnings.warn(
            "In causal configuration, cumulative layer normalization (cgLN) "
            "or channel-wise layer normalization (cLN) must be used. "
            f"Changing {norm_type} to cLN."
        )
        norm_type = "cLN"
    # Update in_chan
    masker = TDConvNet(
        n_feats,
        n_src,
        out_chan=out_chan,
        n_blocks=n_blocks,
        n_repeats=n_repeats,
        bn_chan=bn_chan,
        hid_chan=hid_chan,
        skip_chan=skip_chan,
        conv_kernel_size=conv_kernel_size,
        norm_type=norm_type,
        mask_act=mask_act,
        causal=causal,
    )
    super().__init__(encoder, masker, decoder, encoder_activation=encoder_activation)
def __init__(
    self,
    n_src,
    n_heads=4,
    ff_hid=256,
    chunk_size=100,
    hop_size=None,  # 50
    n_repeats=6,  # 2
    norm_type="gLN",
    ff_activation="relu",
    encoder_activation="relu",
    mask_act="relu",  # sigmoid
    bidirectional=True,
    dropout=0,
    in_chan=None,  # 64
    fb_name="free",
    kernel_size=16,
    n_filters=64,
    stride=8,
    sample_rate=8000,
    **fb_kwargs,  # out_chan=64
):
    # The encoder and decoder are just two Conv1d layers with independent
    # filterbanks (kernel_size=16, stride=8, 64 filters by default).
    # encoder: Conv1d on (batch, 1, time) -> (batch, freq/chan, stft_time)
    # decoder: ConvTranspose1d on (batch, freq, stft_time) -> (batch, 1, time)
    # ConvTranspose1d is the gradient of Conv1d, not a true deconvolution.
    encoder, decoder = make_enc_dec(
        fb_name,
        kernel_size=kernel_size,
        n_filters=n_filters,
        stride=stride,
        sample_rate=sample_rate,
        **fb_kwargs,
    )
    # n_feats equals n_filters for the free filterbank.
    n_feats = encoder.n_feats_out
    if in_chan is not None:
        assert in_chan == n_feats, (
            "Number of filterbank output channels and number of input channels "
            f"should be the same. Received {n_feats} and {in_chan}"
        )
    # Update in_chan
    masker = DPTransformer(
        n_feats,
        n_src,
        n_heads=n_heads,
        ff_hid=ff_hid,
        ff_activation=ff_activation,
        chunk_size=chunk_size,
        hop_size=hop_size,
        n_repeats=n_repeats,
        norm_type=norm_type,
        mask_act=mask_act,
        bidirectional=bidirectional,
        dropout=dropout,
    )
    super().__init__(encoder, masker, decoder, encoder_activation=encoder_activation)
def test_make_enc_dec(who):
    fb_config = {"n_filters": 500, "kernel_size": 16, "stride": 8}
    enc, dec = make_enc_dec("free", who_is_pinv=who, **fb_config)
    enc, dec = make_enc_dec(FreeFB, who_is_pinv=who, **fb_config)
    assert enc.filterbank == fb.get(enc.filterbank)
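# make_enc_dec accepts either a registered filterbank name ("free", "stft", ...)
# or a Filterbank class, as the test above exercises. A minimal round-trip
# sketch (assuming the same imports as this test module):
enc, dec = make_enc_dec("free", n_filters=64, kernel_size=16, stride=8)
wav = torch.randn(1, 1, 8000)
rebuilt = dec(enc(wav))  # roughly wav's shape, up to edge effects of the framing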
from asteroid.dsp.beamforming import (
    Beamformer,
    SCM,
    RTFMVDRBeamformer,
    SDWMWFBeamformer,
    GEVBeamformer,
    stable_cholesky,
)

torch_has_complex_support = tuple(map(int, torch.__version__.split(".")[:2])) >= (1, 8)

_stft, _istft = make_enc_dec("stft", kernel_size=512, n_filters=512, stride=128)
stft = lambda x: tr.to_torch_complex(_stft(x))
istft = lambda x: _istft(tr.from_torch_complex(x))


@pytest.mark.skipif(not torch_has_complex_support, reason="No complex support")
def _default_beamformer_test(beamformer: Beamformer, n_mics=4, *args, **kwargs):
    scm = SCM()
    speech = torch.randn(1, n_mics, 16000 * 6)
    noise = torch.randn(1, n_mics, 16000 * 6)
    mix = speech + noise