Example #1
from torch import nn

# MultiHeadAttention is assumed to be defined or imported elsewhere in this module.


class TransformerEncoderDecoderConnectionBlock(nn.Module):
    """Encoder-decoder (cross-attention) block with a residual connection and LayerNorm."""

    def __init__(self, hidden_dim: int, key_query_value_dim: int = 64, num_attention_heads: int = 8, with_hard_concrete_gate: bool = False):
        super().__init__()

        self.multihead_attention = MultiHeadAttention(hidden_dim, key_and_query_dim=key_query_value_dim, value_dim=key_query_value_dim, num_heads=num_attention_heads, with_hard_concrete_gate=with_hard_concrete_gate)
        self.norm = nn.LayerNorm(hidden_dim)

    def forward(self, encoder_outputs, decoder_hidden, mask=None):
        # Queries come from the decoder; keys and values come from the encoder.
        attention_outputs = self.multihead_attention(q_hidden_inputs=decoder_hidden, k_hidden_inputs=encoder_outputs, v_hidden_inputs=encoder_outputs, mask=mask)
        # Residual connection around the attention output, followed by layer normalization.
        return self.norm(decoder_hidden + attention_outputs)
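A minimal usage sketch for the block above, assuming the MultiHeadAttention used in its constructor accepts queries and keys/values with different sequence lengths, as standard cross-attention does; the tensor sizes are illustrative only.

import torch

hidden_dim = 64
encoder_outputs = torch.rand((3, 10, hidden_dim))  # (batch, source_len, hidden_dim)
decoder_hidden = torch.rand((3, 7, hidden_dim))    # (batch, target_len, hidden_dim)

block = TransformerEncoderDecoderConnectionBlock(hidden_dim)
output = block(encoder_outputs, decoder_hidden, mask=None)
# The residual connection and LayerNorm preserve the decoder hidden shape.
assert output.size() == decoder_hidden.size()
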
Example #2
import torch

# MultiHeadAttention is assumed to be defined or imported elsewhere in this module.


def test_multihead_attention():
    with torch.no_grad():
        kq_dim = 4
        v_dim = 8
        num_heads = 16
        hidden_dim = 64

        batch_size = 3
        seq_len = 7
        attention_input = torch.rand((batch_size, seq_len, hidden_dim))

        mha = MultiHeadAttention(hidden_dim,
                                 key_and_query_dim=kq_dim,
                                 value_dim=v_dim,
                                 num_heads=num_heads)
        # Self-attention: queries, keys, and values all come from the same input.
        mha_output = mha(q_hidden_inputs=attention_input,
                         k_hidden_inputs=attention_input,
                         v_hidden_inputs=attention_input,
                         mask=None)
        # The output projection maps back to hidden_dim, so the input shape is preserved.
        assert mha_output.size() == attention_input.size()