Example #1
    def __init__(
        self,
        hidden_size,
        dropout=0.0,
        bidirectional=False,
        mem_type="hc",
        norm_type="cLN",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.bidirectional = bidirectional
        self.input_size = (int(bidirectional) + 1) * hidden_size
        self.mem_type = mem_type

        assert mem_type in [
            "hc",
            "h",
            "c",
            "id",
        ], f"only support 'hc', 'h', 'c' and 'id', current type: {mem_type}"

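        # Build a separate LSTM + normalization branch for the hidden-state ("h")
        # and cell-state ("c") memories; "hc" builds both, while "id" builds
        # neither and simply passes the states through unchanged.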
        if mem_type in ["hc", "h"]:
            self.h_net = SingleRNN(
                "LSTM",
                input_size=self.input_size,
                hidden_size=self.hidden_size,
                dropout=dropout,
                bidirectional=bidirectional,
            )
            self.h_norm = choose_norm(
                norm_type=norm_type, channel_size=self.input_size, shape="BTD"
            )
        if mem_type in ["hc", "c"]:
            self.c_net = SingleRNN(
                "LSTM",
                input_size=self.input_size,
                hidden_size=self.hidden_size,
                dropout=dropout,
                bidirectional=bidirectional,
            )
            self.c_norm = choose_norm(
                norm_type=norm_type, channel_size=self.input_size, shape="BTD"
            )
Example #2
    def __init__(
        self,
        rnn_type,
        input_size,
        att_heads,
        hidden_size,
        dropout=0.0,
        activation="relu",
        bidirectional=True,
        norm="gLN",
    ):
        super().__init__()

        rnn_type = rnn_type.upper()
        assert rnn_type in [
            "RNN",
            "LSTM",
            "GRU",
        ], f"Only support 'RNN', 'LSTM' and 'GRU', current type: {rnn_type}"
        self.rnn_type = rnn_type

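        # Multi-head self-attention sub-layer with its own dropout and normalization.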
        self.att_heads = att_heads
        self.self_attn = nn.MultiheadAttention(input_size, att_heads, dropout=dropout)
        self.dropout = nn.Dropout(p=dropout)
        self.norm_attn = choose_norm(norm, input_size)

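        # Single-layer RNN followed by a feed-forward block (activation, dropout,
        # linear); in DPTNet this RNN-based feed-forward replaces the first linear
        # layer of the standard Transformer feed-forward network.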
        self.rnn = getattr(nn, rnn_type)(
            input_size,
            hidden_size,
            1,
            batch_first=True,
            bidirectional=bidirectional,
        )

        activation = get_activation(activation)
        hdim = 2 * hidden_size if bidirectional else hidden_size
        self.feed_forward = nn.Sequential(
            activation, nn.Dropout(p=dropout), nn.Linear(hdim, input_size)
        )

        self.norm_ff = choose_norm(norm, input_size)
Example #3
    def __init__(
        self, input_size, hidden_size, dropout=0.0, bidirectional=False, norm_type="cLN"
    ):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_direction = int(bidirectional) + 1

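        # Single (bi)LSTM with dropout, a linear projection back to input_size,
        # and normalization.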
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            1,
            batch_first=True,
            bidirectional=bidirectional,
        )
        self.dropout = nn.Dropout(p=dropout)
        self.proj = nn.Linear(hidden_size * self.num_direction, input_size)
        self.norm = choose_norm(
            norm_type=norm_type, channel_size=input_size, shape="BTD"
        )
Example #4
    def __init__(
        self,
        input_dim: int,
        post_enc_relu: bool = True,
        rnn_type: str = "lstm",
        bidirectional: bool = True,
        num_spk: int = 2,
        unit: int = 256,
        att_heads: int = 4,
        dropout: float = 0.0,
        activation: str = "relu",
        norm_type: str = "gLN",
        layer: int = 6,
        segment_size: int = 20,
        nonlinear: str = "relu",
    ):
        """Dual-Path Transformer Network (DPTNet) Separator

        Args:
            input_dim: input feature dimension
            post_enc_relu: bool, whether to apply a ReLU to the encoded
                feature before separation
            rnn_type: string, select from 'RNN', 'LSTM' and 'GRU'.
            bidirectional: bool, whether the inter-chunk RNN layers are bidirectional.
            num_spk: number of speakers
            unit: int, dimension of the hidden state.
            att_heads: number of attention heads.
            dropout: float, dropout ratio. Default is 0.
            activation: activation function applied at the output of RNN.
            norm_type: type of normalization to use after each inter- or
                intra-chunk Transformer block.
            nonlinear: the nonlinear function for mask estimation,
                       select from 'relu', 'tanh', 'sigmoid'
            layer: int, number of stacked dual-path Transformer blocks.
                Default is 6.
            segment_size: dual-path segment size
        """
        super().__init__()

        self._num_spk = num_spk
        self.segment_size = segment_size

        self.post_enc_relu = post_enc_relu
        self.enc_LN = choose_norm(norm_type, input_dim)
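        # Dual-path transformer stack; output_size = input_dim * num_spk so the
        # output can later be split into one mask per speaker.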
        self.dptnet = DPTNet(
            rnn_type=rnn_type,
            input_size=input_dim,
            hidden_size=unit,
            output_size=input_dim * num_spk,
            att_heads=att_heads,
            dropout=dropout,
            activation=activation,
            num_layers=layer,
            bidirectional=bidirectional,
            norm_type=norm_type,
        )
        # gated output layer
        self.output = torch.nn.Sequential(
            torch.nn.Conv1d(input_dim, input_dim, 1),
            torch.nn.Tanh(),
        )
        self.output_gate = torch.nn.Sequential(
            torch.nn.Conv1d(input_dim, input_dim, 1),
            torch.nn.Sigmoid(),
        )

        if nonlinear not in ("sigmoid", "relu", "tanh"):
            raise ValueError("Not supporting nonlinear={}".format(nonlinear))

        self.nonlinear = {
            "sigmoid": torch.nn.Sigmoid(),
            "relu": torch.nn.ReLU(),
            "tanh": torch.nn.Tanh(),
        }[nonlinear]
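
These constructors appear to come from ESPnet's DPTNet/SkiM separation modules. Below is a minimal usage sketch for Example #4, assuming the __init__ above belongs to ESPnet's DPTNetSeparator class; the import path, class name, and forward() interface are assumptions and are not shown in the snippet itself.

import torch

# Hypothetical sketch: assumes the constructor in Example #4 belongs to
# espnet2's DPTNetSeparator and that it follows the usual ESPnet separator
# interface (forward(feature, ilens) -> (masked, ilens, others)).
from espnet2.enh.separator.dptnet_separator import DPTNetSeparator

separator = DPTNetSeparator(
    input_dim=64,        # encoder feature dimension
    num_spk=2,           # separate two speakers
    unit=256,            # hidden size of each improved-transformer layer
    att_heads=4,
    layer=6,
    segment_size=20,
)

feature = torch.randn(4, 1000, 64)                  # (batch, frames, input_dim)
ilens = torch.full((4,), 1000, dtype=torch.long)    # valid frame counts per utterance
masked, ilens, others = separator(feature, ilens)   # assumed ESPnet separator interface
print(len(masked), masked[0].shape)                 # 2 masked features, each (4, 1000, 64)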