Esempio n. 1
0
    def __init__(
        self, kernel_size, stride, in_chan, n_src, bn_chan, chunk_size, hop_size, mask_act
    ):
        """Create the output, gating and mask layers plus a "free" decoder.

        Args:
            kernel_size: Kernel size forwarded to the "free" filterbank.
            stride: Stride forwarded to the "free" filterbank.
            in_chan: Output channels of the mask convolution; also used as
                the filterbank's ``n_filters``.
            n_src: Multiplier applied to ``bn_chan`` for the output channels
                of the first 1x1 ``Conv2d``.
            bn_chan: Channel count used by the 1x1 convolutions.
            chunk_size: Stored on the instance (not otherwise used here).
            hop_size: Stored on the instance (not otherwise used here).
            mask_act: Key resolved with ``activations.get`` to build the
                output activation.
        """
        super(SingleDecoder, self).__init__()
        # Remember the constructor arguments on the instance.
        self.kernel_size, self.stride = kernel_size, stride
        self.in_chan, self.bn_chan = in_chan, bn_chan
        self.chunk_size, self.hop_size = chunk_size, hop_size
        self.n_src, self.mask_act = n_src, mask_act

        # Masking in 3D space: PReLU followed by a 1x1 Conv2d that expands
        # bn_chan to n_src * bn_chan channels.
        expand_conv = nn.Conv2d(bn_chan, n_src * bn_chan, 1)
        self.first_out = nn.Sequential(nn.PReLU(), expand_conv)
        # Gating and masking in 2D space (after fold): tanh and sigmoid
        # branches, then a bias-free projection to in_chan.
        self.net_out = nn.Sequential(nn.Conv1d(bn_chan, bn_chan, 1), nn.Tanh())
        self.net_gate = nn.Sequential(nn.Conv1d(bn_chan, bn_chan, 1), nn.Sigmoid())
        self.mask_net = nn.Conv1d(bn_chan, in_chan, 1, bias=False)

        # Resolve the activation class; when it accepts `dim` (e.g. softmax),
        # apply it along dimension 1 (the source dimension).
        act_cls = activations.get(mask_act)
        act_kwargs = {"dim": 1} if has_arg(act_cls, "dim") else {}
        self.output_act = act_cls(**act_kwargs)

        # Only the decoder half of the encoder/decoder pair is kept.
        _unused_encoder, free_decoder = make_enc_dec(
            "free", kernel_size=kernel_size, stride=stride, n_filters=in_chan
        )
        self.trans_conv = free_decoder
Esempio n. 2
0
    def __init__(
        self,
        architecture,
        stft_n_filters=1024,
        stft_kernel_size=1024,
        stft_stride=256,
        sample_rate=16000.0,
        **masknet_kwargs,
    ):
        """Assemble an STFT encoder/decoder around a default-architecture masker.

        Args:
            architecture: Passed to ``self.masknet_class.default_architecture``.
            stft_n_filters: ``n_filters`` of the STFT filterbank.
            stft_kernel_size: ``kernel_size`` of the STFT filterbank.
            stft_stride: ``stride`` of the STFT filterbank.
            sample_rate: Sample rate forwarded to ``make_enc_dec``.
            **masknet_kwargs: Extra keyword arguments for the masker factory.
        """
        # Keep the configuration on the instance before building anything.
        self.architecture = architecture
        self.stft_n_filters, self.stft_kernel_size = stft_n_filters, stft_kernel_size
        self.stft_stride = stft_stride
        self.masknet_kwargs = masknet_kwargs

        stft_conf = {
            "n_filters": stft_n_filters,
            "kernel_size": stft_kernel_size,
            "stride": stft_stride,
            "sample_rate": sample_rate,
        }
        stft_enc, stft_dec = make_enc_dec("stft", **stft_conf)
        mask_net = self.masknet_class.default_architecture(architecture, **masknet_kwargs)
        super().__init__(stft_enc, mask_net, stft_dec)
Esempio n. 3
0
    def __init__(
        self,
        activation='tanh',
        latent_dim=64,
        hidden_dim_encoder=128,
        n_filters=1024,
        kernel_size=1024,
        stride=256,
        padding=3,
        sample_rate=8000
    ):
        """Build an STFT encoder/decoder with a ``VAE_inner`` masker.

        Args:
            activation: Forwarded to ``VAE_inner``.
            latent_dim: Forwarded to ``VAE_inner``.
            hidden_dim_encoder: Forwarded to ``VAE_inner``.
            n_filters: STFT ``n_filters``; the masker input size is
                ``n_filters // 2 + 1``.
            kernel_size: STFT ``kernel_size``.
            stride: STFT ``stride``.
            padding: Forwarded to ``VAE_inner``.
            sample_rate: Sample rate forwarded to ``make_enc_dec``.
        """
        stft_enc, stft_dec = make_enc_dec(
            "stft",
            n_filters=n_filters,
            kernel_size=kernel_size,
            stride=stride,
            sample_rate=sample_rate,
        )

        n_bins = n_filters // 2 + 1
        self.activation = activation
        self.latent_dim = latent_dim
        self.hidden_dim_encoder = hidden_dim_encoder
        self.input_dim = n_bins
        self.padding = padding

        vae = VAE_inner(n_bins, latent_dim, hidden_dim_encoder, padding, activation)
        super().__init__(stft_enc, vae, stft_dec)

        # Scaler statistics start as zeros; `has_scaler` flags them as unset.
        for buffer_name in ('scaler_mean', 'scaler_std'):
            self.register_buffer(buffer_name, torch.zeros(n_bins))
        self.has_scaler = False
Esempio n. 4
0
def make_model_and_optimizer(conf):
    """Build the model and its optimizer from a configuration dictionary.

    Keeping construction in one place makes reloading for resuming and for
    evaluation simple.

    Args:
        conf: Dictionary with (at least) the "filterbank", "masknet" and
            "optim" sections, each used as keyword arguments.

    Returns:
        model, optimizer.
    """
    # Encoder and decoder come straight from the filterbank section.
    enc, dec = fb.make_enc_dec(**conf["filterbank"])

    # The encoder's channel multipliers account for input post-processing and
    # for the mask type, which change the masker's input/output sizes.
    n_feats = enc.n_feats_out
    masker_in = int(n_feats * enc.in_chan_mul)
    masker_out = int(n_feats * enc.out_chan_mul)
    mask_net = TDConvNet(in_chan=masker_in, out_chan=masker_out, **conf["masknet"])

    # Alternatively the correction could live inside Model, but then the
    # masker would have to be instantiated there as well.
    model = Model(enc, mask_net, dec)

    # The optimizer is also instantiated from the dictionary.
    return model, make_optimizer(model.parameters(), **conf["optim"])
Esempio n. 5
0
def make_model_and_optimizer(conf):
    """Build the STFT-based model and its optimizer from a configuration.

    Keeping construction in one place makes reloading for resuming and for
    evaluation simple.

    Args:
        conf: Dictionary with "filterbank", "masknet", "optim" and
            "main_args" sections. ``conf["masknet"]`` is updated in place
            with the computed ``input_size`` and ``output_size``.

    Returns:
        model, optimizer.
    """
    is_complex = conf["main_args"]["is_complex"]
    stft, istft = make_enc_dec("stft", **conf["filterbank"])
    n_feats = stft.n_feats_out

    if is_complex:
        # (re, im, mag) are concatenated as input and a complex mask is output.
        sizes = {"input_size": int(n_feats * 3 / 2), "output_size": n_feats}
    else:
        half = int(n_feats / 2)
        sizes = {"input_size": half, "output_size": half}

    # Record the sizes in the mask-network section before instantiating it.
    conf["masknet"].update(sizes)
    mask_net = SimpleModel(**conf["masknet"])

    model = Model(stft, mask_net, istft, is_complex=is_complex)
    return model, make_optimizer(model.parameters(), **conf["optim"])
Esempio n. 6
0
 def __init__(
     self,
     n_src,
     out_chan=None,
     bn_chan=128,
     hid_size=128,
     chunk_size=100,
     hop_size=None,
     n_repeats=6,
     norm_type="gLN",
     mask_act="sigmoid",
     bidirectional=True,
     rnn_type="LSTM",
     num_layers=1,
     dropout=0,
     in_chan=None,
     fb_name="free",
     kernel_size=16,
     n_filters=64,
     stride=8,
     encoder_activation=None,
     sample_rate=8000,
     **fb_kwargs,
 ):
     """Wire a filterbank encoder/decoder pair around a ``DPRNN`` masker.

     The filterbank arguments (``fb_name``, ``kernel_size``, ``n_filters``,
     ``stride``, ``sample_rate``, ``**fb_kwargs``) go to ``make_enc_dec``;
     the remaining model arguments go to ``DPRNN``. If ``in_chan`` is given
     it must equal the encoder's ``n_feats_out``.
     """
     fb_conf = dict(
         kernel_size=kernel_size,
         n_filters=n_filters,
         stride=stride,
         sample_rate=sample_rate,
     )
     enc, dec = make_enc_dec(fb_name, **fb_conf, **fb_kwargs)
     n_feats = enc.n_feats_out
     # in_chan is optional, but when provided it has to match the encoder.
     assert in_chan is None or in_chan == n_feats, (
         "Number of filterbank output channels"
         " and number of input channels should "
         "be the same. Received "
         f"{n_feats} and {in_chan}")
     masker_conf = dict(
         out_chan=out_chan,
         bn_chan=bn_chan,
         hid_size=hid_size,
         chunk_size=chunk_size,
         hop_size=hop_size,
         n_repeats=n_repeats,
         norm_type=norm_type,
         mask_act=mask_act,
         bidirectional=bidirectional,
         rnn_type=rnn_type,
         num_layers=num_layers,
         dropout=dropout,
     )
     # The masker's input channel count is the encoder's feature count.
     mask_net = DPRNN(n_feats, n_src, **masker_conf)
     super().__init__(enc, mask_net, dec, encoder_activation=encoder_activation)
Esempio n. 7
0
def test_perfect_istft_default_parameters(fb_config):
    """Check STFT -> iSTFT reconstruction, ignoring one kernel at each edge."""
    encoder, decoder = make_enc_dec("stft", **fb_config)
    wav = torch.randn(2, 1, 32000)
    # Compare only the interior samples: drop `kernel_size` samples per side.
    k = fb_config["kernel_size"]
    interior = slice(k, -k)
    reconstructed = decoder(encoder(wav))
    testing.assert_allclose(wav[:, :, interior], reconstructed[:, :, interior])
Esempio n. 8
0
    def __init__(
        self,
        n_srcs,
        bn_chan=128,
        hid_size=128,
        chunk_size=100,
        hop_size=None,
        n_repeats=6,
        norm_type="gLN",
        mask_act="sigmoid",
        bidirectional=True,
        rnn_type="LSTM",
        num_layers=1,
        dropout=0,
        kernel_size=16,
        n_filters=64,
        stride=8,
        encoder_activation=None,
        use_mulcat=False,
        sample_rate=8000,
    ):
        """Build the free encoder, the multi-stage DPRNN masker and the decoder selector.

        ``hop_size`` defaults to ``chunk_size // 2`` when it is ``None``.
        ``encoder_activation`` falls back to "linear" when falsy. Only the
        encoder returned by ``make_enc_dec`` is kept; decoding is delegated
        to ``Decoder_Select``.
        """
        super().__init__(sample_rate=sample_rate)
        self.encoder_activation = encoder_activation
        act_cls = activations.get(encoder_activation or "linear")
        self.enc_activation = act_cls()
        if hop_size is None:
            hop_size = chunk_size // 2

        free_encoder, _unused_decoder = make_enc_dec(
            "free",
            kernel_size=kernel_size,
            n_filters=n_filters,
            stride=stride,
        )
        self.encoder = free_encoder

        # Chunking parameters shared by the masker and the decoder selector.
        chunking = dict(bn_chan=bn_chan, chunk_size=chunk_size, hop_size=hop_size)
        # The masker consumes n_filters channels.
        self.masker = DPRNN_MultiStage(
            in_chan=n_filters,
            hid_size=hid_size,
            n_repeats=n_repeats,
            norm_type=norm_type,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            use_mulcat=use_mulcat,
            num_layers=num_layers,
            dropout=dropout,
            **chunking,
        )
        self.decoder_select = Decoder_Select(
            kernel_size=kernel_size,
            stride=stride,
            in_chan=n_filters,
            n_srcs=n_srcs,
            mask_act=mask_act,
            **chunking,
        )

        """
Esempio n. 9
0
    def __init__(
        self,
        target="TCS",
        inner_channels=64,
        dilated_layers=10,
        total_layers=13,
        max_dilation=None,
        n_filters=2048,
        kernel_size=2048,
        stride=1024,
        sample_rate=8000,
    ):
        """Build an STFT front/back end around a stack of dilated and late layers.

        Args:
            target: One of "TMS", "TCS" or "cIRM". "TMS" uses one input
                channel and a single-channel Softplus output; the other two
                use two input and two output channels.
            inner_channels: Channel count between layers.
            dilated_layers: Number of ``SMoLnetDilatedLayer`` blocks; block
                ``i`` uses dilation ``2**i``, capped at ``max_dilation``.
            total_layers: ``total_layers - dilated_layers`` must be positive;
                that many ``SMoLnetLateLayer`` blocks follow.
            max_dilation: Optional upper bound on the dilation.
            n_filters: STFT ``n_filters``.
            kernel_size: STFT ``kernel_size``.
            stride: STFT ``stride``.
            sample_rate: Sample rate forwarded to ``make_enc_dec``.
        """
        stft_enc, stft_dec = make_enc_dec(
            "stft",
            n_filters=n_filters,
            kernel_size=kernel_size,
            stride=stride,
            sample_rate=sample_rate,
        )

        assert target in ("TMS", "TCS", "cIRM")
        self.target = target
        self.inner_channels = inner_channels
        self.dilated_layers = dilated_layers
        self.total_layers = total_layers
        self.max_dilation = max_dilation

        input_channels = 2 if target in ("cIRM", "TCS") else 1

        num_square_layers = total_layers - dilated_layers
        assert num_square_layers > 0

        # Exponentially growing dilations, optionally capped.
        dilations = [2**idx for idx in range(dilated_layers)]
        if max_dilation is not None:
            dilations = [min(dil, max_dilation) for dil in dilations]

        blocks = []
        chan = input_channels
        for dil in dilations:
            blocks.append(SMoLnetDilatedLayer(chan, inner_channels, dilation=dil))
            # Every block after the first one sees inner_channels inputs.
            chan = inner_channels

        blocks.extend(
            SMoLnetLateLayer(inner_channels, inner_channels)
            for _ in range(num_square_layers)
        )

        # Output head: 1 channel + Softplus for "TMS", 2 raw channels otherwise.
        if target == "TMS":
            blocks += [nn.Conv2d(inner_channels, 1, kernel_size=1), nn.Softplus()]
        else:
            blocks.append(nn.Conv2d(inner_channels, 2, kernel_size=1))

        super().__init__(stft_enc, nn.Sequential(*blocks), stft_dec)
Esempio n. 10
0
def test_stft_def(fb_config):
    """Filters built by hand and via ``make_enc_dec`` must agree."""
    shared_fb = STFTFB(**fb_config)
    manual_enc = Encoder(shared_fb)
    manual_dec = Decoder(shared_fb)
    factory_enc, factory_dec = make_enc_dec("stft", **fb_config)
    # Encoder first, then decoder.
    for manual, factory in ((manual_enc, factory_enc), (manual_dec, factory_dec)):
        testing.assert_allclose(
            manual.filterbank.filters(), factory.filterbank.filters()
        )
Esempio n. 11
0
 def __init__(
     self,
     n_src,
     n_heads=4,
     ff_hid=256,
     chunk_size=100,
     hop_size=None,
     n_repeats=6,
     norm_type="gLN",
     ff_activation="relu",
     encoder_activation="relu",
     mask_act="relu",
     bidirectional=True,
     dropout=0,
     in_chan=None,
     fb_name="free",
     kernel_size=16,
     n_filters=64,
     stride=8,
     sample_rate=8000,
     **fb_kwargs,
 ):
     """Wire a filterbank encoder/decoder pair around a ``DPTransformer`` masker.

     Filterbank arguments are forwarded to ``make_enc_dec``, transformer
     arguments to ``DPTransformer``. If ``in_chan`` is given it must equal
     the encoder's ``n_feats_out``.
     """
     enc, dec = make_enc_dec(
         fb_name,
         kernel_size=kernel_size,
         n_filters=n_filters,
         stride=stride,
         sample_rate=sample_rate,
         **fb_kwargs,
     )
     n_feats = enc.n_feats_out
     # Validate the optional in_chan against what the encoder produces.
     assert in_chan is None or in_chan == n_feats, (
         "Number of filterbank output channels"
         " and number of input channels should "
         "be the same. Received "
         f"{n_feats} and {in_chan}")
     transformer_conf = {
         "n_heads": n_heads,
         "ff_hid": ff_hid,
         "ff_activation": ff_activation,
         "chunk_size": chunk_size,
         "hop_size": hop_size,
         "n_repeats": n_repeats,
         "norm_type": norm_type,
         "mask_act": mask_act,
         "bidirectional": bidirectional,
         "dropout": dropout,
     }
     # The masker's input channel count is the encoder's feature count.
     mask_net = DPTransformer(n_feats, n_src, **transformer_conf)
     super().__init__(enc, mask_net, dec, encoder_activation=encoder_activation)
Esempio n. 12
0
 def __init__(
     self,
     n_src,
     out_chan=None,
     n_blocks=8,
     n_repeats=3,
     bn_chan=128,
     hid_chan=512,
     skip_chan=128,
     conv_kernel_size=3,
     norm_type="gLN",
     mask_act="sigmoid",
     in_chan=None,
     causal=False,
     fb_name="free",
     kernel_size=16,
     n_filters=512,
     stride=8,
     encoder_activation=None,
     sample_rate=8000,
     **fb_kwargs,
 ):
     """Wire a filterbank encoder/decoder pair around a ``TDConvNet`` masker.

     Filterbank arguments are forwarded to ``make_enc_dec``, the
     convolutional-network arguments to ``TDConvNet``. If ``in_chan`` is
     given it must equal the encoder's ``n_feats_out``.
     """
     fb_conf = {
         "kernel_size": kernel_size,
         "n_filters": n_filters,
         "stride": stride,
         "sample_rate": sample_rate,
     }
     enc, dec = make_enc_dec(fb_name, **fb_conf, **fb_kwargs)
     n_feats = enc.n_feats_out
     # in_chan is optional, but when provided it has to match the encoder.
     assert in_chan is None or in_chan == n_feats, (
         "Number of filterbank output channels"
         " and number of input channels should "
         "be the same. Received "
         f"{n_feats} and {in_chan}")
     # The masker's input channel count is the encoder's feature count.
     mask_net = TDConvNet(
         n_feats,
         n_src,
         out_chan=out_chan,
         n_blocks=n_blocks,
         n_repeats=n_repeats,
         bn_chan=bn_chan,
         hid_chan=hid_chan,
         skip_chan=skip_chan,
         conv_kernel_size=conv_kernel_size,
         norm_type=norm_type,
         mask_act=mask_act,
         causal=causal,
     )
     super().__init__(enc, mask_net, dec, encoder_activation=encoder_activation)
Esempio n. 13
0
    def __init__(
        self,
        input_type="mag",
        output_type="mag",
        hidden_dims=(1024, ),
        dropout=0.0,
        activation="relu",
        mask_act="relu",
        norm_type="gLN",
        fb_name="stft",
        n_filters=512,
        stride=256,
        kernel_size=512,
        sample_rate=16000,
        **fb_kwargs,
    ):
        """Build the filterbank pair and the DeMask masker.

        The deprecated ``fb_type`` keyword, when present and truthy in
        ``fb_kwargs``, overrides ``fb_name`` and triggers a
        ``VisibleDeprecationWarning``. Masker input/output sizes are derived
        from ``input_type`` / ``output_type`` and the encoder's
        ``n_feats_out``.
        """
        # Backward compatibility: accept the old `fb_type` keyword.
        legacy_fb = fb_kwargs.pop("fb_type", None)
        if legacy_fb:
            warnings.warn(
                "Using `fb_type` keyword argument is deprecated and "
                "will be removed in v0.4.0. Use `fb_name` instead.",
                VisibleDeprecationWarning,
            )
            fb_name = legacy_fb

        fb_conf = dict(
            kernel_size=kernel_size,
            n_filters=n_filters,
            stride=stride,
            sample_rate=sample_rate,
        )
        enc, dec = make_enc_dec(fb_name, **fb_conf, **fb_kwargs)

        masker_in = self._get_n_feats_input(input_type, enc.n_feats_out)
        masker_out = self._get_n_feats_output(output_type, enc.n_feats_out)
        mask_net = build_demask_masker(
            masker_in,
            masker_out,
            norm_type=norm_type,
            activation=activation,
            hidden_dims=hidden_dims,
            dropout=dropout,
            mask_act=mask_act,
        )
        super().__init__(enc, mask_net, dec)

        # Keep the masker configuration on the instance.
        self.input_type, self.output_type = input_type, output_type
        self.hidden_dims, self.dropout = hidden_dims, dropout
        self.activation, self.mask_act = activation, mask_act
        self.norm_type = norm_type
Esempio n. 14
0
    def __init__(
        self,
        n_src,
        out_chan=None,
        rnn_type="lstm",
        n_layers=4,
        hid_size=512,
        dropout=0.3,
        mask_act="sigmoid",
        bidirectional=True,
        in_chan=None,
        fb_name="free",
        n_filters=64,
        kernel_size=16,
        stride=8,
        encoder_activation=None,
        sample_rate=8000,
        **fb_kwargs,
    ):
        """Wire a gated filterbank encoder, an ``LSTMMasker`` and a decoder.

        Filterbank arguments are forwarded to ``make_enc_dec``; the raw
        encoder is then wrapped in ``_GatedEncoder``. If ``in_chan`` is given
        it must equal the raw encoder's ``n_feats_out``.
        """
        raw_encoder, dec = make_enc_dec(
            fb_name,
            kernel_size=kernel_size,
            n_filters=n_filters,
            stride=stride,
            sample_rate=sample_rate,
            **fb_kwargs,
        )
        # Read the feature count before the encoder gets wrapped.
        n_feats = raw_encoder.n_feats_out
        assert in_chan is None or in_chan == n_feats, (
            "Number of filterbank output channels"
            " and number of input channels should "
            "be the same. Received "
            f"{n_feats} and {in_chan}")

        # The encoder actually used by the model is the gated wrapper.
        gated_encoder = _GatedEncoder(raw_encoder)

        rnn_conf = dict(
            out_chan=out_chan,
            hid_size=hid_size,
            mask_act=mask_act,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            n_layers=n_layers,
            dropout=dropout,
        )
        mask_net = LSTMMasker(n_feats, n_src, **rnn_conf)
        super().__init__(
            gated_encoder, mask_net, dec, encoder_activation=encoder_activation
        )
Esempio n. 15
0
def make_model_and_optimizer(conf):
    """Build the Chimera-based model and its optimizer from a configuration.

    Keeping construction in one place makes reloading for resuming and for
    evaluation simple.

    Args:
        conf: Dictionary with "filterbank", "masknet" and "optim" sections,
            each used as keyword arguments.

    Returns:
        model, optimizer.
    """
    stft_enc, stft_dec = fb.make_enc_dec("stft", **conf["filterbank"])
    # The masker is sized with half of the encoder's output features.
    half_feats = stft_enc.n_feats_out // 2
    mask_net = Chimera(half_feats, **conf["masknet"])
    model = Model(stft_enc, mask_net, stft_dec)
    return model, make_optimizer(model.parameters(), **conf["optim"])
def test_jit_filterbanks_enc(filter_bank_name, inference_data):
    """Traced encoder output must match the eager encoder output."""
    n_filters = 32
    # TorchSTFTFB gets kernel_size == n_filters; other filterbanks get half.
    kernel_size = n_filters if filter_bank_name == TorchSTFTFB else n_filters // 2
    encoder, _unused_decoder = make_enc_dec(
        filter_bank_name, n_filters=n_filters, kernel_size=kernel_size
    )

    # Example input for tracing: uniform noise rescaled to [-1, 1).
    example = (torch.rand(1, 200) - 0.5) * 2
    traced = torch.jit.trace(encoder, (example,))
    with torch.no_grad():
        eager_out = encoder(inference_data)
        traced_out = traced(inference_data)
        # Dump the traced program to stdout.
        print(traced.code_with_constants)
        print(traced.code)
        assert_allclose(eager_out, traced_out)
 def __init__(
     self,
     fb_name="free",
     kernel_size=16,
     n_filters=32,
     stride=8,
     **fb_kwargs,
 ):
     """Create and store an encoder/decoder pair from ``make_enc_dec``.

     When ``fb_name`` is ``TorchSTFTFB``, ``n_filters`` is replaced by
     ``kernel_size`` before building the filterbank.
     """
     super().__init__()
     fb_conf = dict(kernel_size=kernel_size, n_filters=n_filters, stride=stride)
     if fb_name == TorchSTFTFB:
         fb_conf["n_filters"] = kernel_size
     self.encoder, self.decoder = make_enc_dec(fb_name, **fb_conf, **fb_kwargs)
Esempio n. 18
0
def make_model_and_optimizer(conf):
    """Build the DPRNN-based model and its optimizer from a configuration.

    Keeping construction in one place makes reloading for resuming and for
    evaluation simple.

    Args:
        conf: Dictionary with "filterbank", "masknet" and "optim" sections,
            each used as keyword arguments.

    Returns:
        model, optimizer.
    """
    # Free filterbank encoder/decoder plus the DPRNN mask network.
    free_enc, free_dec = fb.make_enc_dec("free", **conf["filterbank"])
    mask_net = DPRNN(**conf["masknet"])
    model = Model(free_enc, mask_net, free_dec)
    # The optimizer works on all parameters of the assembled model.
    return model, make_optimizer(model.parameters(), **conf["optim"])
Esempio n. 19
0
    def __init__(
        self,
        n_src,
        bn_chan=128,
        num_blocks=16,
        upsampling_depth=4,
        mask_act="relu",
        in_chan=None,
        fb_name="free",
        kernel_size=21,
        n_filters=512,
        stride=None,
        sample_rate=8000,
        **fb_kwargs,
    ):
        """Wire a padded filterbank encoder, a ``SuDORMRFImproved`` masker and a decoder.

        ``stride`` falls back to ``kernel_size // 2`` when falsy. The
        filterbank is built with ``padding=kernel_size // 2`` and
        ``output_padding=kernel_size // 2 - 1``. If ``in_chan`` is given it
        must equal the raw encoder's ``n_feats_out``. No encoder activation
        is used.
        """
        half_kernel = kernel_size // 2
        if not stride:
            stride = half_kernel
        # The encoder is needed first to learn the masker's input channels.
        raw_enc, dec = make_enc_dec(
            fb_name,
            kernel_size=kernel_size,
            n_filters=n_filters,
            stride=stride,
            sample_rate=sample_rate,
            padding=half_kernel,
            output_padding=half_kernel - 1,
            **fb_kwargs,
        )
        n_feats = raw_enc.n_feats_out
        padded_enc = _Padder(
            raw_enc, upsampling_depth=upsampling_depth, kernel_size=kernel_size
        )

        assert in_chan is None or in_chan == n_feats, (
            "Number of filterbank output channels"
            " and number of input channels should "
            "be the same. Received "
            f"{n_feats} and {in_chan}")

        mask_net = SuDORMRFImproved(
            n_feats,
            n_src,
            bn_chan=bn_chan,
            num_blocks=num_blocks,
            upsampling_depth=upsampling_depth,
            mask_act=mask_act,
        )
        super().__init__(padded_enc, mask_net, dec, encoder_activation=None)
Esempio n. 20
0
    def __init__(
        self,
        input_type="mag",
        output_type="mag",
        hidden_dims=(1024, ),
        dropout=0.0,
        activation="relu",
        mask_act="relu",
        norm_type="gLN",
        fb_name="stft",
        n_filters=512,
        stride=256,
        kernel_size=512,
        sample_rate=16000,
        **fb_kwargs,
    ):
        """Build the filterbank pair and the DeMask masker.

        Masker input/output sizes are derived from ``input_type`` /
        ``output_type`` and the encoder's ``n_feats_out``; the remaining
        masker arguments are passed to ``build_demask_masker`` and then
        stored on the instance.
        """
        enc, dec = make_enc_dec(
            fb_name,
            kernel_size=kernel_size,
            n_filters=n_filters,
            stride=stride,
            sample_rate=sample_rate,
            **fb_kwargs,
        )

        masker_in = self._get_n_feats_input(input_type, enc.n_feats_out)
        masker_out = self._get_n_feats_output(output_type, enc.n_feats_out)
        masker_conf = dict(
            norm_type=norm_type,
            activation=activation,
            hidden_dims=hidden_dims,
            dropout=dropout,
            mask_act=mask_act,
        )
        mask_net = build_demask_masker(masker_in, masker_out, **masker_conf)
        super().__init__(enc, mask_net, dec)

        # Keep the masker configuration on the instance.
        self.input_type, self.output_type = input_type, output_type
        self.hidden_dims, self.dropout = hidden_dims, dropout
        self.activation, self.mask_act = activation, mask_act
        self.norm_type = norm_type
Esempio n. 21
0
    def __init__(
        self,
        median_kernel_size=11,
        n_filters=2048,
        kernel_size=2048,
        stride=512,
        sample_rate=8000,
    ):
        """Build an STFT encoder/decoder with an identity masker.

        Stores ``median_kernel_size`` and ``pad = kernel_size // 2`` on the
        instance; how they are used is outside this method.
        """
        stft_conf = dict(
            n_filters=n_filters,
            kernel_size=kernel_size,
            stride=stride,
            sample_rate=sample_rate,
        )
        stft_enc, stft_dec = make_enc_dec("stft", **stft_conf)

        self.median_kernel_size = median_kernel_size
        self.pad = kernel_size // 2
        # No learnable masker here: nn.Identity fills the masker slot.
        super().__init__(stft_enc, nn.Identity(), stft_dec)
Esempio n. 22
0
    def __init__(
        self,
        n_filters=512,
        kernel_size=400,
        stride=100,
        sample_rate=8000,
        **masker_kwargs,
    ):
        """Build an STFT encoder/decoder around a ``PhasenMasker``.

        The masker receives ``num_bins = n_filters // 2 + 1`` plus every
        extra keyword argument, which is also stored as ``masker_kwargs``.
        """
        stft_enc, stft_dec = make_enc_dec(
            "stft",
            n_filters=n_filters,
            kernel_size=kernel_size,
            stride=stride,
            sample_rate=sample_rate,
        )

        self.masker_kwargs = masker_kwargs

        n_bins = n_filters // 2 + 1
        mask_net = PhasenMasker(num_bins=n_bins, **masker_kwargs)
        super().__init__(stft_enc, mask_net, stft_dec)
Esempio n. 23
0
    def __init__(
        self,
        activation="relu",
        hidden_layers=(2048, 2048, 2048),
        padding=3,
        dropout=0.2,
        n_filters=256,
        kernel_size=256,
        stride=128,
        use_sigmoid=False,
        sample_rate=8000,
    ):
        """Build an STFT encoder/decoder around a fully connected masker.

        Args:
            activation: Key resolved with ``activations.get`` after every
                hidden linear layer.
            hidden_layers: Output sizes of the hidden linear layers.
            padding: The first layer takes
                ``(n_filters // 2 + 1) * (2 * padding + 1)`` inputs.
            dropout: Dropout probability; no dropout layer is added when 0.
            n_filters: STFT ``n_filters``.
            kernel_size: STFT ``kernel_size``.
            stride: STFT ``stride``.
            use_sigmoid: Append a ``Sigmoid`` after the output layer.
            sample_rate: Sample rate forwarded to ``make_enc_dec``.
        """
        stft_enc, stft_dec = make_enc_dec(
            "stft",
            n_filters=n_filters,
            kernel_size=kernel_size,
            stride=stride,
            sample_rate=sample_rate,
        )

        n_freq = n_filters // 2 + 1
        self.activation = activation
        self.padding = padding
        self.hidden_layers = hidden_layers
        self.dropout = dropout
        self.n_freq = n_freq

        # Layer widths: stacked-frame input followed by each hidden size.
        widths = [n_freq * (padding * 2 + 1), *hidden_layers]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), activations.get(activation)()]
            if dropout != 0:
                stack.append(nn.Dropout(dropout))

        # Project back to one value per frequency bin.
        stack.append(nn.Linear(widths[-1], n_freq))
        if use_sigmoid:
            stack.append(nn.Sigmoid())

        super().__init__(stft_enc, nn.Sequential(*stack), stft_dec)

        # Scaler statistics start as zeros; `has_scaler` flags them as unset.
        for buffer_name in ('scaler_mean', 'scaler_std'):
            self.register_buffer(buffer_name, torch.zeros(n_freq))
        self.has_scaler = False
Esempio n. 24
0
 def __init__(
     self,
     fb_name="free",
     kernel_size=16,
     n_filters=32,
     stride=8,
     encoder_activation=None,
     **fb_kwargs,
 ):
     """Build an encoder/decoder pair with an identity masker in between."""
     fb_conf = dict(kernel_size=kernel_size, n_filters=n_filters, stride=stride)
     enc, dec = make_enc_dec(fb_name, **fb_conf, **fb_kwargs)
     # torch.nn.Identity fills the masker slot.
     super().__init__(
         enc, torch.nn.Identity(), dec, encoder_activation=encoder_activation
     )
Esempio n. 25
0
    def __init__(
        self,
        activation="tanh",
        n_filters=256,
        kernel_size=256,
        stride=128,
        padding=3,
        hid_dim = 2048,
        z_dim = 64,
        sample_rate=8000
    ):
        """Build an STFT encoder/decoder and the linear layers of the VAE.

        The base class receives a placeholder identity masker; the layers
        created afterwards (``enc1``..``enc3``, ``enc_mu_logvar``, ``dec1``,
        ``dec2``) are stored directly on the instance.

        Args:
            activation: Stored on the instance (not otherwise used here).
            n_filters: STFT ``n_filters``; ``n_freq`` is ``n_filters // 2 + 1``.
            kernel_size: STFT ``kernel_size``.
            stride: STFT ``stride``.
            padding: ``enc1`` takes ``n_freq * (2 * padding + 1)`` inputs.
            hid_dim: Width of the hidden linear layers.
            z_dim: Output size of ``enc_mu_logvar`` and input size of ``dec1``.
            sample_rate: Sample rate forwarded to ``make_enc_dec``.
        """
        n_freq = n_filters // 2 + 1
        stacked_dim = n_freq * (padding * 2 + 1)
        self.padding = padding
        self.n_freq = n_freq
        self.hid_dim = hid_dim
        self.z_dim = z_dim

        stft_enc, stft_dec = make_enc_dec(
            "stft",
            n_filters=n_filters,
            kernel_size=kernel_size,
            stride=stride,
            sample_rate=sample_rate,
        )
        # Placeholder masker so the base-class constructor can be called;
        # the arguments given to nn.Identity are ignored by it.
        placeholder = nn.Sequential(
            nn.Identity(54, unused_argument1=0.1, unused_argument2=False)
        )

        self.activation = activation

        super().__init__(stft_enc, placeholder, stft_dec)

        # Layers of the actual network.
        self.enc1 = nn.Linear(stacked_dim, hid_dim)
        self.enc2 = nn.Linear(hid_dim, hid_dim)
        self.enc3 = nn.Linear(hid_dim, n_freq)
        self.enc_mu_logvar = nn.Linear(n_freq, z_dim)
        self.dec1 = nn.Linear(z_dim, hid_dim)
        self.dec2 = nn.Linear(hid_dim, n_freq)

        # Scaler statistics start as zeros; `has_scaler` flags them as unset.
        for buffer_name in ('scaler_mean', 'scaler_std'):
            self.register_buffer(buffer_name, torch.zeros(n_freq))
        self.has_scaler = False
Esempio n. 26
0
    def __init__(
        self,
        target_metric="ESTOI",
        mask_threshold=0.05,
        n_filters=512,
        kernel_size=512,
        stride=256,
        sample_rate=8000,
    ):
        """Build the STFT front/back end, generator, discriminator and metric module.

        Args:
            target_metric: One of "PESQ", "STOI" or "ESTOI"; selects the
                metric module.
            mask_threshold: Stored on the instance (not otherwise used here).
            n_filters: STFT ``n_filters``; the generator is sized with
                ``n_filters // 2 + 1``.
            kernel_size: STFT ``kernel_size``.
            stride: STFT ``stride``.
            sample_rate: Sample rate forwarded to ``make_enc_dec``.
        """
        assert target_metric in ("PESQ", "STOI", "ESTOI")
        self.target_metric = target_metric
        self.mask_threshold = mask_threshold

        stft_enc, stft_dec = make_enc_dec(
            "stft",
            n_filters=n_filters,
            kernel_size=kernel_size,
            stride=stride,
            sample_rate=sample_rate,
        )

        gen = Generator_Sigmoid_LSTM_Masker(n_dim=n_filters // 2 + 1)
        disc = Discriminator_Stride1_SN()

        # The generator takes the masker slot of the encoder-masker-decoder base.
        BaseEncoderMaskerDecoder.__init__(self, stft_enc, gen, stft_dec)

        self.generator = gen
        self.discriminator = disc

        # STOI and ESTOI share one module, told apart by a boolean flag;
        # anything else uses the PESQ module.
        if target_metric in ("STOI", "ESTOI"):
            self.metric_module = MetricSTOI(self.sample_rate, target_metric == "ESTOI")
        else:
            self.metric_module = MetricPESQ(self.sample_rate)
Esempio n. 27
0
 def __init__(
     self,
     n_src,
     out_chan=None,
     n_blocks=8,
     n_repeats=3,
     bn_chan=128,
     hid_chan=512,
     skip_chan=128,
     conv_kernel_size=3,
     norm_type="gLN",
     mask_act="sigmoid",
     in_chan=None,
     causal=False,
     fb_name="free",
     kernel_size=16,
     n_filters=512,
     stride=8,
     encoder_activation=None,
     sample_rate=8000,
     **fb_kwargs,
 ):
     """Wire a filterbank encoder/decoder pair around a ``TDConvNet`` masker.

     Filterbank arguments (``fb_name``, ``kernel_size``, ``n_filters``,
     ``stride``, ``sample_rate``, ``**fb_kwargs``) are forwarded to
     ``make_enc_dec``; the convolutional-network arguments go to
     ``TDConvNet``.

     If ``in_chan`` is given it must equal the encoder's ``n_feats_out``.
     When ``causal`` is true and ``norm_type`` is neither "cgLN" nor "cLN",
     a warning naming the requested normalization is issued and "cLN" is
     used instead.
     """
     encoder, decoder = make_enc_dec(
         fb_name,
         kernel_size=kernel_size,
         n_filters=n_filters,
         stride=stride,
         sample_rate=sample_rate,
         **fb_kwargs,
     )
     n_feats = encoder.n_feats_out
     if in_chan is not None:
         assert in_chan == n_feats, ("Number of filterbank output channels"
                                     " and number of input channels should "
                                     "be the same. Received "
                                     f"{n_feats} and {in_chan}")
     if causal and norm_type not in ["cgLN", "cLN"]:
         # Warn BEFORE overwriting norm_type so the message reports the
         # normalization that was actually requested.
         warnings.warn(
             "In causal configuration cumulative layer normalization (cgLN) "
             "or channel-wise layer normalization (cLN) "
             f"must be used. Changing {norm_type} to cLN")
         norm_type = "cLN"
     # The masker's input channel count is the encoder's feature count.
     masker = TDConvNet(
         n_feats,
         n_src,
         out_chan=out_chan,
         n_blocks=n_blocks,
         n_repeats=n_repeats,
         bn_chan=bn_chan,
         hid_chan=hid_chan,
         skip_chan=skip_chan,
         conv_kernel_size=conv_kernel_size,
         norm_type=norm_type,
         mask_act=mask_act,
         causal=causal,
     )
     super().__init__(encoder,
                      masker,
                      decoder,
                      encoder_activation=encoder_activation)
Esempio n. 28
0
 def __init__(
         self,
         n_src,
         n_heads=4,
         ff_hid=256,
         chunk_size=100,
         hop_size=None,
         n_repeats=6,
         norm_type="gLN",
         ff_activation="relu",
         encoder_activation="relu",
         mask_act="relu",
         bidirectional=True,
         dropout=0,
         in_chan=None,
         fb_name="free",
         kernel_size=16,
         n_filters=64,
         stride=8,
         sample_rate=8000,
         **fb_kwargs,
 ):
     """Wire a filterbank encoder/decoder pair around a ``DPTransformer`` masker.

     Filterbank arguments are forwarded to ``make_enc_dec``, transformer
     arguments to ``DPTransformer``. If ``in_chan`` is given it must equal
     the encoder's ``n_feats_out``.
     """
     # NOTE(review): per the original author's notes, with the default
     # "free" filterbank the encoder is a Conv1d mapping (batch, 1, time) to
     # (batch, chan, frames) and the decoder a transposed Conv1d mapping
     # back, each with its own filters - confirm against make_enc_dec.
     fb_conf = dict(
         kernel_size=kernel_size,
         n_filters=n_filters,
         stride=stride,
         sample_rate=sample_rate,
     )
     enc, dec = make_enc_dec(fb_name, **fb_conf, **fb_kwargs)
     # Presumably equal to n_filters for the default filterbank.
     n_feats = enc.n_feats_out
     assert in_chan is None or in_chan == n_feats, (
         "Number of filterbank output channels"
         " and number of input channels should "
         "be the same. Received "
         f"{n_feats} and {in_chan}")
     # The masker's input channel count is the encoder's feature count.
     mask_net = DPTransformer(
         n_feats,
         n_src,
         n_heads=n_heads,
         ff_hid=ff_hid,
         ff_activation=ff_activation,
         chunk_size=chunk_size,
         hop_size=hop_size,
         n_repeats=n_repeats,
         norm_type=norm_type,
         mask_act=mask_act,
         bidirectional=bidirectional,
         dropout=dropout,
     )
     super().__init__(enc, mask_net, dec, encoder_activation=encoder_activation)
def test_make_enc_dec(who):
    """``make_enc_dec`` accepts both a filterbank name and a filterbank class."""
    fb_config = dict(n_filters=500, kernel_size=16, stride=8)
    # Build once from the string key and once from the class itself; the
    # final check runs on the encoder from the class-based call.
    for fb_spec in ("free", FreeFB):
        enc, dec = make_enc_dec(fb_spec, who_is_pinv=who, **fb_config)
    assert enc.filterbank == fb.get(enc.filterbank)
Esempio n. 30
0
from asteroid.dsp.beamforming import (
    Beamformer,
    SCM,
    RTFMVDRBeamformer,
    SDWMWFBeamformer,
    GEVBeamformer,
    stable_cholesky,
)

# True when the installed torch (major, minor) version is at least (1, 8).
torch_has_complex_support = tuple(
    int(part) for part in torch.__version__.split(".")[:2]
) >= (1, 8)

# Shared STFT encoder/decoder pair used by the helpers below.
_stft, _istft = make_enc_dec("stft", kernel_size=512, n_filters=512, stride=128)


def stft(x):
    """Apply the STFT encoder and convert its output with ``tr.to_torch_complex``."""
    return tr.to_torch_complex(_stft(x))


def istft(x):
    """Convert ``x`` with ``tr.from_torch_complex`` and apply the STFT decoder."""
    return _istft(tr.from_torch_complex(x))


@pytest.mark.skipif(not torch_has_complex_support, "No complex support ")
def _default_beamformer_test(beamformer: Beamformer,
                             n_mics=4,
                             *args,
                             **kwargs):
    scm = SCM()

    speech = torch.randn(1, n_mics, 16000 * 6)
    noise = torch.randn(1, n_mics, 16000 * 6)
    mix = speech + noise