Example #1
    def __init__(self,
                 num_classes: int,                    # number of classification classes
                 max_length: int = 120,               # maximum allowed length of the decoded sequence
                 hidden_dim: int = 1024,              # dimension of the RNN's hidden state vector
                 sos_id: int = 1,                     # start-of-sentence token's id
                 eos_id: int = 2,                     # end-of-sentence token's id
                 attn_mechanism: str = 'multi-head',  # type of attention mechanism
                 num_heads: int = 4,                  # number of attention heads
                 num_layers: int = 2,                 # number of RNN layers
                 rnn_type: str = 'lstm',              # type of RNN cell
                 dropout_p: float = 0.3,              # dropout probability
                 device: str = 'cuda') -> None:       # device - 'cuda' or 'cpu'
        # Base-class RNN setup; the trailing False presumably disables bidirectionality (cf. bidirectional=False in Example #2).
        super(SpeechDecoderRNN, self).__init__(hidden_dim, hidden_dim, num_layers, rnn_type, dropout_p, False, device)
        self.num_classes = num_classes
        self.num_heads = num_heads
        self.max_length = max_length
        self.eos_id = eos_id
        self.sos_id = sos_id
        self.attn_mechanism = attn_mechanism.lower()
        self.embedding = nn.Embedding(num_classes, hidden_dim)
        self.input_dropout = nn.Dropout(dropout_p)

        if self.attn_mechanism == 'loc':
            self.attention = AddNorm(LocationAwareAttention(hidden_dim, smoothing=True), hidden_dim)
        elif self.attn_mechanism == 'multi-head':
            self.attention = AddNorm(MultiHeadAttention(hidden_dim, num_heads), hidden_dim)
        elif self.attn_mechanism == 'additive':
            self.attention = AddNorm(AdditiveAttention(hidden_dim), hidden_dim)
        elif self.attn_mechanism == 'scaled-dot':
            self.attention = AddNorm(ScaledDotProductAttention(hidden_dim), hidden_dim)
        else:
            raise ValueError("Unsupported attention: {0}".format(attn_mechanism))

        self.projection = AddNorm(Linear(hidden_dim, hidden_dim, bias=True), hidden_dim)
        self.generator = Linear(hidden_dim, num_classes, bias=False)
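
A minimal usage sketch for the constructor above, assuming SpeechDecoderRNN and its attention modules are importable from the surrounding project; the import of torch and all concrete values below are placeholders for illustration, not part of the source.

# Hypothetical usage sketch; vocabulary size and other values are assumptions.
import torch

decoder = SpeechDecoderRNN(
    num_classes=2000,         # placeholder vocabulary size
    max_length=120,
    hidden_dim=1024,
    sos_id=1,
    eos_id=2,
    attn_mechanism='loc',     # selects the LocationAwareAttention branch above
    num_heads=4,              # unused by 'loc'; kept at its default
    num_layers=2,
    rnn_type='lstm',
    dropout_p=0.3,
    device='cuda' if torch.cuda.is_available() else 'cpu',
)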
Example #2
    def __init__(
        self,
        num_classes: int,
        max_length: int = 150,
        hidden_state_dim: int = 1024,
        pad_id: int = 0,
        sos_id: int = 1,
        eos_id: int = 2,
        attn_mechanism: str = 'multi-head',
        num_heads: int = 4,
        num_layers: int = 2,
        rnn_type: str = 'lstm',
        dropout_p: float = 0.3,
    ) -> None:
        super(DecoderRNN, self).__init__()
        self.hidden_state_dim = hidden_state_dim
        self.num_classes = num_classes
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.max_length = max_length
        self.eos_id = eos_id
        self.sos_id = sos_id
        self.pad_id = pad_id
        self.attn_mechanism = attn_mechanism.lower()
        self.embedding = nn.Embedding(num_classes, hidden_state_dim)
        self.input_dropout = nn.Dropout(dropout_p)
        rnn_cell = self.supported_rnns[rnn_type.lower()]  # class-level registry, e.g. {'lstm': nn.LSTM, ...}
        self.rnn = rnn_cell(
            input_size=hidden_state_dim,
            hidden_size=hidden_state_dim,
            num_layers=num_layers,
            bias=True,
            batch_first=True,
            dropout=dropout_p,
            bidirectional=False,
        )

        if self.attn_mechanism == 'loc':
            self.attention = LocationAwareAttention(hidden_state_dim,
                                                    attn_dim=hidden_state_dim,
                                                    smoothing=False)
        elif self.attn_mechanism == 'multi-head':
            self.attention = MultiHeadAttention(hidden_state_dim,
                                                num_heads=num_heads)
        elif self.attn_mechanism == 'additive':
            self.attention = AdditiveAttention(hidden_state_dim)
        elif self.attn_mechanism == 'scaled-dot':
            self.attention = ScaledDotProductAttention(dim=hidden_state_dim)
        else:
            raise ValueError(
                "Unsupported attention: {0}".format(attn_mechanism))

        self.fc = nn.Sequential(
            Linear(hidden_state_dim << 1, hidden_state_dim),  # hidden_state_dim << 1 == 2 * hidden_state_dim
            nn.Tanh(),
            View(shape=(-1, self.hidden_state_dim), contiguous=True),  # reshape to (-1, hidden_state_dim)
            Linear(hidden_state_dim, num_classes),  # vocabulary logits
        )
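
A matching usage sketch for Example #2. Note that hidden_state_dim << 1 equals 2 * hidden_state_dim; the doubled input to the fc head suggests the RNN output is concatenated with the attention context before projection, though that step falls outside this excerpt. All concrete values below are placeholders.

# Hypothetical usage sketch; vocabulary size and other values are assumptions.
decoder = DecoderRNN(
    num_classes=2000,             # placeholder vocabulary size
    max_length=150,
    hidden_state_dim=1024,
    pad_id=0,
    sos_id=1,
    eos_id=2,
    attn_mechanism='scaled-dot',  # selects the ScaledDotProductAttention branch above
    num_layers=2,
    rnn_type='lstm',              # must be a key of the class-level supported_rnns mapping
    dropout_p=0.3,
)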