def _single_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor,
        adj_matrix: np.ndarray,
        distance_matrix: np.ndarray,
        dropout_p: float = 0.0,
        eps: float = 1e-6,
        inf: float = 1e12) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the output of a single MAT attention layer.

    Parameters
    ----------
    query: torch.Tensor
        Standard query parameter for attention.
    key: torch.Tensor
        Standard key parameter for attention.
    value: torch.Tensor
        Standard value parameter for attention.
    mask: torch.Tensor
        Masks out padding values so that they are not taken into account
        when computing the attention score.
    adj_matrix: np.ndarray
        Adjacency matrix of the input molecule, returned from
        dc.feat.MATFeaturizer().
    distance_matrix: np.ndarray
        Distance matrix of the input molecule, returned from
        dc.feat.MATFeaturizer().
    dropout_p: float
        Dropout probability.
    eps: float
        Epsilon value for numerical stability.
    inf: float
        Value used as infinity when masking attention scores.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor]
        The attention output and the softmax attention weights.
    """
    # Standard scaled dot-product attention scores.
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Broadcast the padding mask across heads and rows, then push masked
        # positions to -inf so softmax assigns them zero weight.
        scores = scores.masked_fill(
            mask.unsqueeze(1).repeat(1, query.shape[1], query.shape[2],
                                     1) == 0, -inf)
    p_attn = F.softmax(scores, dim=-1)

    # Row-normalize the adjacency matrix so each atom's weights over its
    # neighbors sum to (approximately) 1, then repeat it across heads.
    adj_matrix = torch.tensor(adj_matrix, device=query.device).float()
    adj_matrix = adj_matrix / (adj_matrix.sum(dim=-1).unsqueeze(-1) + eps)
    p_adj = adj_matrix.unsqueeze(1).repeat(1, query.shape[1], 1, 1)

    # Mask padded positions with infinite distance so the distance kernel
    # maps them to zero weight, then repeat across heads.
    distance_matrix = torch.tensor(
        distance_matrix, device=query.device).float().masked_fill(
            mask.repeat(1, mask.shape[-1], 1) == 0, np.inf)
    distance_matrix = self.dist_kernel(distance_matrix)
    p_dist = distance_matrix.unsqueeze(1).repeat(1, query.shape[1], 1, 1)

    # Mix the three attention sources with the layer's lambda weights, then
    # apply dropout using the documented dropout_p argument.
    p_weighted = (self.lambda_attention * p_attn +
                  self.lambda_distance * p_dist +
                  self.lambda_adjacency * p_adj)
    p_weighted = F.dropout(p_weighted, p=dropout_p, training=self.training)

    return torch.matmul(p_weighted.float(), value.float()), p_attn
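
# A minimal, self-contained sketch (not part of the class) of the mixture the
# method above computes. The lambda weights, the exponential distance kernel,
# and the toy chain-molecule shapes below are illustrative assumptions; in MAT
# they come from the layer's constructor and from dc.feat.MATFeaturizer().
if __name__ == "__main__":
    import math

    import numpy as np
    import torch
    import torch.nn.functional as F

    batch, heads, n_atoms, d_k = 1, 4, 5, 8
    query = torch.randn(batch, heads, n_atoms, d_k)
    key = torch.randn(batch, heads, n_atoms, d_k)
    value = torch.randn(batch, heads, n_atoms, d_k)

    # Standard scaled dot-product attention term.
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    p_attn = F.softmax(scores, dim=-1)

    # Adjacency term for a 5-atom chain, row-normalized and repeated per head.
    adj = torch.tensor(np.eye(n_atoms, k=1) + np.eye(n_atoms, k=-1),
                       dtype=torch.float32).unsqueeze(0)
    adj = adj / (adj.sum(dim=-1, keepdim=True) + 1e-6)
    p_adj = adj.unsqueeze(1).repeat(1, heads, 1, 1)

    # Distance term: an exponential kernel maps large (or infinite, i.e.
    # padded) distances to near-zero weight.
    coords = torch.randn(batch, n_atoms, 3)
    dist = torch.cdist(coords, coords)
    p_dist = torch.exp(-dist).unsqueeze(1).repeat(1, heads, 1, 1)

    # Convex mixture of the three sources; the constants stand in for the
    # layer's lambda_attention, lambda_distance and lambda_adjacency.
    p_weighted = 0.33 * p_attn + 0.33 * p_dist + 0.34 * p_adj
    out = torch.matmul(p_weighted, value)
    print(out.shape)  # torch.Size([1, 4, 5, 8])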