Exemplo n.º 1
0
    def message(self, r, e, a):
        """Compute attention-weighted messages along every edge plus a self term.

        weight_ij is the importance factor of node j to node i,
        weight_ji is the importance factor of node i to node j, and
        weight_ii is the self-attention factor of every node. Each score is
        exponentiated and divided by a per-node denominator. For any node n,
        that denominator sums:
          - weight_ij over the edges whose column 0 is n,
          - weight_ji over the edges whose column 1 is n,
          - and weight_ii[n].
        The normalized weights touching n therefore sum to one.

        Args:
            r (torch.Tensor): node features, indexed by node along dim 0 and
                concatenated along dim 1 (so at least 2-D).
            e (TYPE): edge features; not used by this method.
            a (torch.LongTensor): edge list; column 0 holds node i and
                column 1 holds node j of each edge.

        Returns:
            tuple: three tensors
                (r[i] * a_ji, r[j] * a_ij, r * a_ii), each with the weight
                broadcast over the feature dimension. The first is the
                contribution of node i's features to node j (a_ji is normalized
                at j). The second is the contribution of node j's features to
                node i (a_ij is normalized at i). The third is the self-attended
                node features.
                NOTE(review): this assumes the caller scatters element 0 onto
                a[:, 1] and element 1 onto a[:, 0] — confirm against aggregate().
        """
        i_idx = a[:, 0]
        j_idx = a[:, 1]
        num_nodes = r.shape[0]

        # Unnormalized attention scores from concatenated pair features.
        # weight_ij: importance of node j's features to node i.
        weight_ij = torch.exp(self.activation(
            torch.cat((r[i_idx], r[j_idx]), dim=1) * self.weight).sum(-1))
        # weight_ji: importance of node i's features to node j.
        weight_ji = torch.exp(self.activation(
            torch.cat((r[j_idx], r[i_idx]), dim=1) * self.weight).sum(-1))
        # weight_ii: self-attention score of every node.
        weight_ii = torch.exp(self.activation(
            torch.cat((r, r), dim=1) * self.weight).sum(-1))

        # Per-node denominator. weight_ij is attributed to node i, weight_ji
        # to node j, and weight_ii to the node itself.
        normalization = (scatter_add(weight_ij, i_idx, dim_size=num_nodes)
                         + scatter_add(weight_ji, j_idx, dim_size=num_nodes)
                         + weight_ii)

        a_ij = weight_ij / normalization[i_idx]  # importance of node j's features to node i
        a_ji = weight_ji / normalization[j_idx]  # importance of node i's features to node j
        a_ii = weight_ii / normalization  # self-attention

        # Fix: the features of i (flowing to j) must carry a_ji, which is
        # normalized at j. The previous version weighted both edge messages
        # with a_ij and left a_ji computed but unused.
        message = (r[i_idx] * a_ji[:, None],
                   r[j_idx] * a_ij[:, None],
                   r * a_ii[:, None])

        return message
Exemplo n.º 2
0
    def forward(self, xyz, bond_adj, bond_len, bond_par):
        """Per-atom bond-stretch energy, splitting each bond's energy between its two atoms.

        Each bond contributes bond_par * (d - bond_len) ** 2, where d is the
        distance between its two atoms. Half of that contribution is credited
        to each endpoint.

        Args:
            xyz: atomic positions indexed by atom along dim 0
                (presumably shape (n_atoms, 3) — confirm with caller).
            bond_adj: bond list; columns 0 and 1 hold the two atom indices.
            bond_len: reference bond lengths, broadcast against the
                (n_bonds, 1) distance column.
            bond_par: per-bond prefactors, broadcast the same way.

        Returns:
            Tensor with xyz.shape[0] rows holding each atom's share of the
            bond energy.
        """
        first = bond_adj[:, 0]
        second = bond_adj[:, 1]
        n_atoms = xyz.shape[0]

        # Euclidean bond distances as a column vector.
        delta = xyz[first] - xyz[second]
        dist = delta.pow(2).sum(1).sqrt()[:, None]

        bond_energy = bond_par * (dist - bond_len) ** 2

        # Credit half of every bond's energy to each of its endpoints.
        share_first = 0.5 * scatter_add(src=bond_energy, index=first, dim=0, dim_size=n_atoms)
        share_second = 0.5 * scatter_add(src=bond_energy, index=second, dim=0, dim_size=n_atoms)
        return share_first + share_second
Exemplo n.º 3
0
    def forward(self, m_ji, e_rbf, nbr_list, num_atoms):
        """Collapse edge messages onto atoms and run the result through the dense stack.

        Args:
            m_ji: edge messages, one row per entry of nbr_list.
            e_rbf: radial-basis features of the edges, projected by
                self.edge_dense and multiplied elementwise with m_ji.
            nbr_list: neighbor list; column 1 gives the atom each message
                is summed into.
            num_atoms: number of rows in the atom-level output.

        Returns:
            Atom features after every layer in self.dense_layers.
        """
        gated = self.edge_dense(e_rbf) * m_ji

        # Example: messages m = {m_ji} = [m_{0,1}, m_{0,2}, m_{1,0}, m_{2,0}]
        # come with nbr_list = [[0, 1], [0, 2], [1, 0], [2, 0]]. Summing over
        # j sends the first message to atom 1 and the second to atom 2. The
        # last two go to atom 0. Hence the target index is nbr_list[:, 1].
        target = nbr_list[:, 1]
        summed = scatter_add(gated.transpose(0, 1), target, dim_size=num_atoms)
        node_feats = summed.transpose(0, 1)

        for layer in self.dense_layers:
            node_feats = layer(node_feats)

        return node_feats
Exemplo n.º 4
0
    def forward(self, m_ji, e_rbf, a_sbf, kj_idx, ji_idx):
        """Combine distance, angle and neighbor-message information for each ji edge.

        Args:
            m_ji (torch.Tensor): edge vector (one row per neighbor-list entry).
            e_rbf (torch.Tensor): radial basis representation of the
                distances.
            a_sbf (torch.Tensor): spherical basis representation of the
                distances and angles.
            kj_idx (torch.LongTensor): nbr_list indices of the k,j pair of
                each angle in the angle list.
            ji_idx (torch.LongTensor): nbr_list indices of the j,i pair of
                each angle in the angle list.

        Returns:
            out (torch.Tensor): aggregated angle and distance information,
                with as many rows as m_ji, meant to be added to m_ji.
        """
        # Dense projections. Distance features are gathered by ji and messages
        # by kj, so row t of each lines up with angle t.
        rbf_proj = self.e_dense(e_rbf[ji_idx])
        msg_proj = self.m_kj_dense(m_ji[kj_idx])
        sbf_proj = self.a_dense(a_sbf)

        # Distance-gated message, down-projected, then gated by the angle term.
        per_angle = self.down_conv(msg_proj * rbf_proj) * sbf_proj

        # Sum every angle's contribution into its ji edge, then up-project.
        per_edge = scatter_add(per_angle.transpose(0, 1), ji_idx,
                               dim_size=m_ji.shape[0]).transpose(0, 1)
        return self.up_conv(per_edge)
Exemplo n.º 5
0
    def forward(self, m_ji, e_rbf, a_sbf, kj_idx, ji_idx):
        """Bilinear combination of angle, distance and neighbor messages per ji edge.

        Args:
            m_ji (torch.Tensor): edge vector (one row per neighbor-list entry).
            e_rbf (torch.Tensor): radial basis representation of the
                distances.
            a_sbf (torch.Tensor): spherical basis representation of the
                distances and angles.
            kj_idx (torch.LongTensor): nbr_list indices of the k,j pair of
                each angle in the angle list.
            ji_idx (torch.LongTensor): nbr_list indices of the j,i pair of
                each angle in the angle list.

        Returns:
            out (torch.Tensor): aggregated angle and distance information,
                with as many rows as m_ji, meant to be added to m_ji.
        """
        # Dense projections; e is gathered by ji and m by kj so that row t
        # of each corresponds to angle t.
        e_ji = self.e_dense(e_rbf[ji_idx])
        m_kj = self.m_kj_dense(m_ji[kj_idx])
        a = self.a_dense(a_sbf)

        gated_msg = m_kj * e_ji

        # Bilinear layer. For every angle t this computes
        #   (w @ gated_msg[t]) @ a[t].
        # The weight w has shape (embed, bilin, embed). Contracting its last
        # axis with gated_msg[t] (embed) gives an (embed, bilin) matrix.
        # Contracting that with a[t] (bilin) gives an embed-sized vector.
        # Stacking over all angles yields an (n_angles, embed) tensor.
        aggr = torch.einsum("wj,wl,ijl->wi", a, gated_msg, self.w)

        # Sum each per-angle fingerprint aggr_{kj,ji} over k. The rows of aggr
        # follow the ordering of the angle list. The ji edge of each angle is
        # given by ji_idx, so we scatter with ji_idx into a tensor with as
        # many rows as m_ji.
        # Example: for angles [0,1,2] and [0,1,3], the rows are aggr_{21,10}
        # and aggr_{31,10}. Both land on edge 10.
        n_edges = m_ji.shape[0]
        summed = scatter_add(aggr.transpose(0, 1), ji_idx, dim_size=n_edges)
        return summed.transpose(0, 1)
Exemplo n.º 6
0
 def aggregate(self, message, index, size):
     """Sum message rows that share the same destination index.

     Args:
         message: values to aggregate, scattered along dim 0.
         index: destination row for each row of `message`.
         size: number of rows in the aggregated output.

     Returns:
         Tensor with `size` rows holding the per-index sums.
     """
     return scatter_add(src=message, index=index, dim=0, dim_size=size)
Exemplo n.º 7
0
 def V_ex(self, xyz, nbr_list, pbc_offsets):
     """Per-atom inverse-power pair potential (sigma / d) ** power.

     Args:
         xyz: atomic positions indexed by atom along dim 0.
         nbr_list: neighbor pairs; column 0 is the atom credited with the
             pair energy, column 1 its neighbor.
         pbc_offsets: displacement added to each pair vector (presumably
             periodic-boundary image shifts — confirm with caller).

     Returns:
         Column tensor with xyz.shape[0] rows holding each atom's summed
         pair energy.
     """
     owner = nbr_list[:, 0]
     partner = nbr_list[:, 1]

     separation = xyz[partner] - xyz[owner] + pbc_offsets
     dist = separation.pow(2).sum(1).sqrt()

     pair_energy = (dist.reciprocal() * self.sigma).pow(self.power)

     # Each pair's energy is credited only to the column-0 atom.
     per_atom = scatter_add(pair_energy, owner, dim_size=xyz.shape[0])
     return per_atom[:, None]