Example #1
0
    def _single_attention(
            self,
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            mask: torch.Tensor,
            adj_matrix: np.ndarray,
            distance_matrix: np.ndarray,
            dropout_p: float = 0.0,
            eps: float = 1e-6,
            inf: float = 1e12) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the output of a single MAT attention head.

        The attention weights are a weighted combination of three terms:
        standard scaled dot-product attention (``p_attn``), a row-normalized
        molecular adjacency matrix (``p_adj``) and a kernelized inter-atomic
        distance matrix (``p_dist``), mixed via ``self.lambda_attention``,
        ``self.lambda_adjacency`` and ``self.lambda_distance``.

        Parameters
        ----------
        query: torch.Tensor
          Standard query parameter for attention.
        key: torch.Tensor
          Standard key parameter for attention.
        value: torch.Tensor
          Standard value parameter for attention.
        mask: torch.Tensor
          Masks out padding values so that they are not taken into account
          when computing the attention score.
        adj_matrix: np.ndarray
          Adjacency matrix of the input molecule, returned from
          dc.feat.MATFeaturizer()
        distance_matrix: np.ndarray
          Distance matrix of the input molecule, returned from
          dc.feat.MATFeaturizer()
        dropout_p: float
          Dropout probability. NOTE(review): this parameter is currently
          unused — dropout is applied through the ``self.dropout_p`` module
          further below, not through this argument.
        eps: float
          Epsilon added to the adjacency row sums to guard against division
          by zero for atoms with no neighbors.
        inf: float
          Large finite value used in place of infinity when masking out
          attention scores (the masked scores become ~0 after softmax).

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
          The attention-weighted values and the raw dot-product attention
          probabilities ``p_attn``.
        """
        # Scaled dot-product attention scores: Q @ K^T / sqrt(d_k).
        d_k = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

        # Replace scores at padded positions with a large negative value so
        # they contribute ~0 probability after the softmax.
        if mask is not None:
            scores = scores.masked_fill(
                mask.unsqueeze(1).repeat(1, query.shape[1], query.shape[2],
                                         1) == 0, -inf)
        p_attn = F.softmax(scores, dim=-1)

        # Row-normalize the adjacency matrix; eps keeps empty rows finite.
        # NOTE(review): dividing the np.ndarray by a torch.Tensor relies on
        # torch's reflected division, so `adj_matrix` becomes a torch.Tensor
        # here — required for the `.repeat` call below. Confirm upstream that
        # adj_matrix always arrives as an ndarray with a leading batch dim.
        adj_matrix = adj_matrix / (
            torch.sum(torch.tensor(adj_matrix), dim=-1).unsqueeze(1) + eps)
        p_adj = adj_matrix.repeat(1, query.shape[1], 1, 1)

        # Mask padded entries of the distance matrix with +inf before the
        # distance kernel so they are suppressed in the kernelized output
        # (exact suppression depends on self.dist_kernel — not visible here).
        distance_matrix = torch.tensor(distance_matrix).masked_fill(
            mask.repeat(1, mask.shape[-1], 1) == 0, np.inf)
        distance_matrix = self.dist_kernel(distance_matrix)
        p_dist = distance_matrix.unsqueeze(1).repeat(1, query.shape[1], 1, 1)
        # Mix the three attention sources, then apply the layer's dropout
        # module (self.dropout_p — not the `dropout_p` float argument).
        p_weighted = self.lambda_attention * p_attn + self.lambda_distance * p_dist + self.lambda_adjacency * p_adj
        p_weighted = self.dropout_p(p_weighted)

        # Broadcast values to the weighted-attention shape and apply weights.
        bd = value.broadcast_to(p_weighted.shape)
        return torch.matmul(p_weighted.float(), bd.float()), p_attn