Example #1
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_mask: Optional[torch.Tensor] = None,
        number_of_keys: int = -1,
        number_of_queries: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        b = -1  # the batch size; -1 lets reshape infer it at runtime
        # Avoid using .size() when possible, as Barracuda does not support the
        # dynamic shape ops it introduces in the exported graph.
        n_q = number_of_queries if number_of_queries != -1 else query.size(1)
        n_k = number_of_keys if number_of_keys != -1 else key.size(1)

        query = self.fc_q(query)  # (b, n_q, h*d)
        key = self.fc_k(key)  # (b, n_k, h*d)
        value = self.fc_v(value)  # (b, n_k, h*d)

        query = query.reshape(b, n_q, self.n_heads, self.embedding_size)
        key = key.reshape(b, n_k, self.n_heads, self.embedding_size)
        value = value.reshape(b, n_k, self.n_heads, self.embedding_size)

        query = query.permute([0, 2, 1, 3])  # (b, h, n_q, emb)
        # The next few lines are equivalent to: key.permute([0, 2, 3, 1])
        # This is a hack: the ONNX exporter fuses two consecutive permutes into
        # a single permute([0, 2, 3, 1]), which Barracuda cannot handle, so a
        # no-op (-= 1, += 1) is placed between them to prevent the fusion.
        key = key.permute([0, 2, 1, 3])  # (b, h, n_k, emb)
        key -= 1
        key += 1
        key = key.permute([0, 1, 3, 2])  # (b, h, emb, n_k)

        qk = torch.matmul(query, key)  # (b, h, n_q, n_k)

        if key_mask is None:
            qk = qk / (self.embedding_size**0.5)
        else:
            key_mask = key_mask.reshape(b, 1, 1, n_k)
            qk = (1 - key_mask) * qk / (
                self.embedding_size**0.5
            ) + key_mask * self.NEG_INF

        att = torch.softmax(qk, dim=3)  # (b, h, n_q, n_k)

        value = value.permute([0, 2, 1, 3])  # (b, h, n_k, emb)
        value_attention = torch.matmul(att, value)  # (b, h, n_q, emb)

        value_attention = value_attention.permute([0, 2, 1, 3])  # (b, n_q, h, emb)
        value_attention = value_attention.reshape(
            b, n_q, self.n_heads * self.embedding_size
        )  # (b, n_q, h*emb)

        out = self.fc_out(value_attention)  # (b, n_q, emb)
        return out, att
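
The in-place -= 1 / += 1 no-op above exists only to stop the ONNX exporter from fusing the two permutes into a single permute([0, 2, 3, 1]), which Barracuda cannot run. Below is a minimal standalone sketch of that trick; the function name and tensor sizes are chosen purely for illustration.

import torch

def barracuda_friendly_transpose(key: torch.Tensor) -> torch.Tensor:
    # Equivalent to key.permute([0, 2, 3, 1]), written as two permutes with a
    # no-op in between so the exporter does not fuse them into one operation.
    key = key.permute([0, 2, 1, 3])  # (b, n_k, h, d) -> (b, h, n_k, d)
    key = key - 1
    key = key + 1
    return key.permute([0, 1, 3, 2])  # (b, h, n_k, d) -> (b, h, d, n_k)

# Quick sanity check; the -1/+1 round trip can perturb values at float32
# precision, hence allclose with a loose tolerance instead of equal.
key = torch.randn(2, 5, 4, 8)  # (b, n_k, h, d)
expected = key.permute([0, 2, 3, 1])
assert torch.allclose(barracuda_friendly_transpose(key), expected, atol=1e-6)
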
Example #2
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        n_q: int,
        n_k: int,
        key_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        b = -1  # the batch size; -1 lets reshape infer it at runtime

        query = query.reshape(b, n_q, self.n_heads,
                              self.head_size)  # (b, n_q, h, emb / h)
        key = key.reshape(b, n_k, self.n_heads,
                          self.head_size)  # (b, n_k, h, emb / h)
        value = value.reshape(b, n_k, self.n_heads,
                              self.head_size)  # (b, n_k, h, emb / h)

        query = query.permute([0, 2, 1, 3])  # (b, h, n_q, emb / h)
        # The next few lines are equivalent to: key.permute([0, 2, 3, 1])
        # This is a hack: the ONNX exporter fuses two consecutive permutes into
        # a single permute([0, 2, 3, 1]), which Barracuda cannot handle, so a
        # no-op (-= 1, += 1) is placed between them to prevent the fusion.
        key = key.permute([0, 2, 1, 3])  # (b, h, n_k, emb / h)
        key -= 1
        key += 1
        key = key.permute([0, 1, 3, 2])  # (b, h, emb / h, n_k)

        qk = torch.matmul(query, key)  # (b, h, n_q, n_k)

        if key_mask is None:
            qk = qk / (self.embedding_size**0.5)
        else:
            key_mask = key_mask.reshape(b, 1, 1, n_k)
            qk = (1 - key_mask) * qk / (
                self.embedding_size**0.5
            ) + key_mask * self.NEG_INF

        att = torch.softmax(qk, dim=3)  # (b, h, n_q, n_k)

        value = value.permute([0, 2, 1, 3])  # (b, h, n_k, emb / h)
        value_attention = torch.matmul(att, value)  # (b, h, n_q, emb / h)

        value_attention = value_attention.permute([0, 2, 1, 3])  # (b, n_q, h, emb / h)
        value_attention = value_attention.reshape(
            b, n_q, self.embedding_size
        )  # (b, n_q, emb)

        return value_attention, att
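
In both examples the key_mask branch performs masking with arithmetic rather than boolean indexing (again for export friendliness): logits for keys where the mask is 1 are replaced by NEG_INF, so those keys receive essentially zero weight after the softmax. A small self-contained sketch of that arithmetic follows; NEG_INF and the tensor sizes are assumed here for illustration.

import torch

NEG_INF = -1e9      # assumed stand-in for self.NEG_INF
embedding_size = 4  # assumed per-head size, used only for the scaling factor

qk = torch.tensor([[[[2.0, 1.0, 3.0]]]])                 # (b=1, h=1, n_q=1, n_k=3)
key_mask = torch.tensor([0.0, 0.0, 1.0]).reshape(1, 1, 1, 3)  # 1 marks a padded key

qk = (1 - key_mask) * qk / (embedding_size**0.5) + key_mask * NEG_INF
att = torch.softmax(qk, dim=3)
print(att)  # roughly [[[[0.62, 0.38, 0.00]]]]: the masked third key is ignored
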