def message(self, r, e, a):
    """Compute attention-weighted messages for every edge and self-loop.

    For each edge k with i = a[k, 0] and j = a[k, 1]:
      * weight_ij scores the importance of node j's features to node i,
      * weight_ji scores the importance of node i's features to node j,
      * weight_ii scores each node's own features (self-attention).

    Every raw score is exponentiated. It is then divided by a per-node
    normalization: the sum of weight_ij over edges where the node is i,
    plus weight_ji over edges where it is j, plus its own weight_ii. This
    is a softmax over all contributions attached to each node.

    As the normalization is built, the coefficients sum to one per node
    only when the first returned message is aggregated onto a[:, 1] and
    the second onto a[:, 0].
    NOTE(review): confirm against the caller.

    Args:
        r (torch.Tensor): node features, indexed by node id along dim 0
            (presumably (num_nodes, num_features) -- confirm).
        e (torch.Tensor): edge features; not used by this method.
        a (torch.LongTensor): edge list; column 0 holds i and column 1
            holds j for each edge.

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            * node i's features weighted by a_ji (the message for node j),
            * node j's features weighted by a_ij (the message for node i),
            * every node's own features weighted by a_ii.
    """
    num_nodes = r.shape[0]
    r_i = r[a[:, 0]]
    r_j = r[a[:, 1]]

    # Unnormalized attention scores.
    # NOTE(review): exp is applied without subtracting a per-node max, so
    # large activation outputs can overflow to inf.
    weight_ij = torch.exp(
        self.activation(torch.cat((r_i, r_j), dim=1) * self.weight).sum(-1))
    weight_ji = torch.exp(
        self.activation(torch.cat((r_j, r_i), dim=1) * self.weight).sum(-1))
    weight_ii = torch.exp(
        self.activation(torch.cat((r, r), dim=1) * self.weight).sum(-1))

    # Total score attached to each node, over both edge directions plus self.
    normalization = (scatter_add(weight_ij, a[:, 0], dim_size=num_nodes)
                     + scatter_add(weight_ji, a[:, 1], dim_size=num_nodes)
                     + weight_ii)

    a_ij = weight_ij / normalization[a[:, 0]]  # importance of j's features to i
    a_ji = weight_ji / normalization[a[:, 1]]  # importance of i's features to j
    a_ii = weight_ii / normalization           # self-attention

    # i's features are weighted by a_ji and j's features by a_ij. The
    # original weighted both by a_ij, which left a_ji unused and broke the
    # per-node normalization for messages travelling i -> j.
    message = (r_i * a_ji[:, None],
               r_j * a_ij[:, None],
               r * a_ii[:, None])

    return message
def forward(self, xyz, bond_adj, bond_len, bond_par):
    """Per-atom harmonic bond energy.

    Each bond in row k of bond_adj contributes
    bond_par * (d - bond_len) ** 2, where d is the current Euclidean
    distance between its two atoms. Half of that energy is credited to
    each of the two atoms.

    Args:
        xyz (torch.Tensor): atomic coordinates, indexed by atom along dim 0.
        bond_adj (torch.LongTensor): bonded atom pairs, one bond per row.
        bond_len (torch.Tensor): equilibrium bond lengths.
        bond_par (torch.Tensor): bond force constants.

    Returns:
        torch.Tensor: energy per atom, with xyz.shape[0] rows.
    """
    num_atoms = xyz.shape[0]
    first_atom = bond_adj[:, 0]
    second_atom = bond_adj[:, 1]

    displacement = xyz[first_atom] - xyz[second_atom]
    # Current bond lengths as a column, shape (num_bonds, 1).
    # NOTE(review): bond_len presumably has the same (num_bonds, 1) shape.
    # A 1-D bond_len would broadcast to (num_bonds, num_bonds) -- confirm.
    bond_dist = displacement.pow(2).sum(1).sqrt().unsqueeze(1)
    ebond = bond_par * (bond_dist - bond_len) ** 2

    half_on_first = 0.5 * scatter_add(src=ebond, index=first_atom, dim=0,
                                      dim_size=num_atoms)
    half_on_second = 0.5 * scatter_add(src=ebond, index=second_atom, dim=0,
                                       dim_size=num_atoms)
    return half_on_first + half_on_second
def forward(self, m_ji, e_rbf, nbr_list, num_atoms):
    """Convert edge messages into per-atom features.

    Each message is multiplied elementwise by a dense projection of its
    radial-basis expansion. The gated messages are summed onto the atom in
    column 1 of nbr_list, and the result is passed through
    self.dense_layers in order.

    Example: with nbr_list = [[0, 1], [0, 2], [1, 0], [2, 0]], the
    messages [m_{0,1}, m_{0,2}, m_{1,0}, m_{2,0}] are added to atoms
    1, 2, 0 and 0 respectively. This sums over the j index of m_ji.

    Args:
        m_ji (torch.Tensor): edge messages, one row per nbr_list row.
        e_rbf (torch.Tensor): radial basis representation of the distances.
        nbr_list (torch.LongTensor): neighbor pairs; row k's message is
            added to atom nbr_list[k, 1].
        num_atoms (int): number of atoms, i.e. rows of the output.

    Returns:
        torch.Tensor: atom features after the dense layers.
    """
    gated = self.edge_dense(e_rbf) * m_ji

    # scatter_add is called on the (features x edges) transpose without a
    # dim argument and the result is transposed back, so output rows are
    # atoms (presumably its default dim is the last axis -- confirm).
    per_atom = scatter_add(gated.transpose(0, 1), nbr_list[:, 1],
                           dim_size=num_atoms)
    node_feats = per_atom.transpose(0, 1)

    for layer in self.dense_layers:
        node_feats = layer(node_feats)
    return node_feats
def forward(self, m_ji, e_rbf, a_sbf, kj_idx, ji_idx):
    """Aggregate angle and distance information onto each ji edge.

    For every angle, the kj message is gated by the projected ji distance
    features and passed through self.down_conv. The result is multiplied
    by the projected angle features, summed over k onto the angle's ji
    edge, and passed through self.up_conv.

    Args:
        m_ji (torch.Tensor): edge vector.
        e_rbf (torch.Tensor): radial basis representation of the distances.
        a_sbf (torch.Tensor): spherical basis representation of the
            distances and angles.
        kj_idx (torch.LongTensor): nbr_list indices corresponding to the
            k,j indices in the angle list.
        ji_idx (torch.LongTensor): nbr_list indices corresponding to the
            j,i indices in the angle list.

    Returns:
        out (torch.Tensor): aggregated angle and distance information to be
            added to m_ji.
    """
    # Per-angle inputs, gathered into angle-list order, then projected.
    e_ji = self.e_dense(e_rbf[ji_idx])
    m_kj = self.m_kj_dense(m_ji[kj_idx])
    a = self.a_dense(a_sbf)

    per_angle = self.down_conv(m_kj * e_ji) * a

    # Sum over k: every angle is added to its ji edge. The output has one
    # row per edge, matching m_ji.
    summed = scatter_add(per_angle.transpose(0, 1), ji_idx,
                         dim_size=m_ji.shape[0])
    return self.up_conv(summed.transpose(0, 1))
def forward(self, m_ji, e_rbf, a_sbf, kj_idx, ji_idx):
    """Bilinear angle interaction aggregated onto each ji edge.

    Args:
        m_ji (torch.Tensor): edge vector.
        e_rbf (torch.Tensor): radial basis representation of the distances.
        a_sbf (torch.Tensor): spherical basis representation of the
            distances and angles.
        kj_idx (torch.LongTensor): nbr_list indices corresponding to the
            k,j indices in the angle list.
        ji_idx (torch.LongTensor): nbr_list indices corresponding to the
            j,i indices in the angle list.

    Returns:
        out (torch.Tensor): aggregated angle and distance information to be
            added to m_ji.
    """
    # Project e (gathered by ji) and m (gathered by kj) into angle-list
    # order, and project the angle features.
    rbf_term = self.e_dense(e_rbf[ji_idx])
    msg_kj = self.m_kj_dense(m_ji[kj_idx])
    sbf_term = self.a_dense(a_sbf)

    # For every angle w:
    #     aggr[w, i] = sum_{j, l} self.w[i, j, l] * sbf_term[w, j]
    #                                             * (msg_kj * rbf_term)[w, l]
    # self.w's last axis contracts with the gated message, its middle axis
    # with the angle features, and its first axis is the output
    # (embed x bilin x embed per the original notes). The result is one
    # embedding vector per angle.
    aggr = torch.einsum("wj,wl,ijl->wi", sbf_term, msg_kj * rbf_term, self.w)

    # Rows of aggr follow the angle list, and ji_idx gives each angle's ji
    # edge. Summing rows that share a ji edge is the sum over k. The output
    # has as many rows as m_ji.
    num_edges = m_ji.shape[0]
    out = scatter_add(aggr.transpose(0, 1), ji_idx, dim_size=num_edges)
    return out.transpose(0, 1)
def aggregate(self, message, index, size):
    """Sum messages that share a destination index along dim 0.

    Args:
        message (torch.Tensor): messages, one per row.
        index (torch.LongTensor): destination row for each message.
        size (int): number of rows in the output.

    Returns:
        torch.Tensor: row n holds the sum of every message whose index is n.
    """
    return scatter_add(src=message, index=index, dim=0, dim_size=size)
def V_ex(self, xyz, nbr_list, pbc_offsets):
    """Per-atom repulsive pair potential (sigma / r) ** power.

    For each neighbor pair, r is the distance from atom nbr_list[k, 0] to
    atom nbr_list[k, 1], shifted by pbc_offsets[k]. The pair's energy is
    credited entirely to atom nbr_list[k, 0].
    NOTE(review): if nbr_list lists each pair in both directions, each
    pair's energy is counted twice in the total -- confirm the convention.

    Args:
        xyz (torch.Tensor): atomic coordinates, indexed by atom along dim 0.
        nbr_list (torch.LongTensor): neighbor pairs, one per row.
        pbc_offsets (torch.Tensor): periodic-image displacement added to
            each pair's separation vector.

    Returns:
        torch.Tensor: per-atom energy with an added trailing axis of size 1.
    """
    separation = xyz[nbr_list[:, 1]] - xyz[nbr_list[:, 0]] + pbc_offsets
    dist = separation.pow(2).sum(1).sqrt()
    pair_energy = (dist.reciprocal() * self.sigma).pow(self.power)
    per_atom = scatter_add(pair_energy, nbr_list[:, 0], dim_size=xyz.shape[0])
    return per_atom.unsqueeze(1)