Esempio n. 1
0
    def forward(self, data):
        """Assemble the edge set for ``data`` and evaluate the model on it.

        The neighbour graph is optionally rebuilt on the fly, edge distances
        and distance vectors are computed (with or without periodic boundary
        conditions), the edges are passed through ``self._filter_edges`` and
        the result is handed to ``self._forward_helper``.

        Side effects: stores ``self.device``, ``self.num_atoms`` and
        ``self.batch_size``; when ``self.otf_graph`` is set, overwrites
        ``data.edge_index``, ``data.cell_offsets`` and ``data.neighbors``.

        Args:
            data: Batch object. Reads ``pos``, ``batch``, ``natoms`` and
                ``atomic_numbers``; with ``self.use_pbc`` also ``cell``,
                ``edge_index``, ``cell_offsets`` and ``neighbors``.

        Returns:
            The value returned by ``self._forward_helper``.
        """
        self.device = data.pos.device
        self.num_atoms = len(data.batch)
        self.batch_size = len(data.natoms)

        atomic_numbers = data.atomic_numbers.long()
        pos = data.pos
        if self.regress_forces:
            # Presumably so forces can be obtained downstream by
            # differentiating with respect to the positions.
            pos = pos.requires_grad_(True)

        if self.otf_graph:
            # Rebuild the periodic neighbour graph instead of trusting the
            # one stored on ``data``.
            edge_index, cell_offsets, neighbors = radius_graph_pbc(
                data, self.cutoff, 100)
            data.edge_index = edge_index
            data.cell_offsets = cell_offsets
            data.neighbors = neighbors

        if self.use_pbc:
            assert (atomic_numbers.dim() == 1
                    and atomic_numbers.dtype == torch.long)

            out = get_pbc_distances(
                pos,
                data.edge_index,
                data.cell,
                data.cell_offsets,
                data.neighbors,
                return_distance_vec=True,
            )

            edge_index = out["edge_index"]
            edge_distance = out["distances"]
            edge_distance_vec = out["distance_vec"]

        else:
            edge_index = radius_graph(pos, r=self.cutoff, batch=data.batch)
            j, i = edge_index
            edge_distance_vec = pos[j] - pos[i]
            edge_distance = edge_distance_vec.norm(dim=-1)

        edge_index, edge_distance, edge_distance_vec = self._filter_edges(
            edge_index,
            edge_distance,
            edge_distance_vec,
            self.max_num_neighbors,
        )

        outputs = self._forward_helper(data, edge_index, edge_distance,
                                       edge_distance_vec)
        if self.show_timing_info is True:
            torch.cuda.synchronize()
            num_edges = len(edge_index[0])
            # Filtering can leave zero edges; clamp the divisor so the
            # diagnostic print cannot abort the forward pass.
            print("Memory: {}\t{}\t{}".format(
                num_edges,
                torch.cuda.memory_allocated() / (1000 * max(num_edges, 1)),
                torch.cuda.max_memory_allocated() / 1000000,
            ))

        return outputs
Esempio n. 2
0
 def update_graph(self, atoms):
     """Attach a freshly computed periodic neighbour graph to ``atoms``.

     ``radius_graph_pbc`` is called with the literal arguments ``6`` and
     ``50`` (presumably cutoff radius and neighbour cap, judging by how
     other callers pass them -- confirm against the helper). Its three
     results are stored on ``atoms`` as ``edge_index``, ``cell_offsets``
     and ``neighbors``.

     Returns:
         ``atoms`` itself, or ``self.transform(atoms)`` when a transform
         is configured.
     """
     graph = radius_graph_pbc(atoms, 6, 50)
     atoms.edge_index, atoms.cell_offsets, atoms.neighbors = graph
     if self.transform is None:
         return atoms
     return self.transform(atoms)
Esempio n. 3
0
    def forward(self, data):
        """Compute the energy (and optionally forces) for a batch.

        Populates ``data.x``, ``data.edge_index`` and ``data.edge_attr`` as
        a side effect, and moves ``self.embedding`` to the device of
        ``data.atomic_numbers`` when the two differ.

        Args:
            data: Batch object providing ``atomic_numbers``, ``pos`` and
                ``batch``; ``cell``, ``edge_index``, ``cell_offsets`` and
                ``neighbors`` are read when ``self.use_pbc`` is set.

        Returns:
            ``energy`` alone, or ``(energy, forces)`` when
            ``self.regress_forces`` is true.
        """
        # Node features: look up the embedding table by atomic number - 1.
        target_device = data.atomic_numbers.device
        if self.embedding.device != target_device:
            self.embedding = self.embedding.to(target_device)
        data.x = self.embedding[data.atomic_numbers.long() - 1]

        pos = data.pos
        if self.regress_forces:
            pos = pos.requires_grad_(True)

        if self.otf_graph:
            (
                data.edge_index,
                data.cell_offsets,
                data.neighbors,
            ) = radius_graph_pbc(data, self.cutoff, 50, data.pos.device)

        if self.use_pbc:
            pbc = get_pbc_distances(
                pos,
                data.edge_index,
                data.cell,
                data.cell_offsets,
                data.neighbors,
            )
            data.edge_index = pbc["edge_index"]
            distances = pbc["distances"]
        else:
            data.edge_index = radius_graph(
                data.pos, r=self.cutoff, batch=data.batch
            )
            src, dst = data.edge_index
            distances = (pos[src] - pos[dst]).norm(dim=-1)

        data.edge_attr = self.distance_expansion(distances)

        # Convolutions, then the fully connected head.
        mol_feats = self.conv_to_fc(self._convolve(data))
        if hasattr(self, "fcs"):
            mol_feats = self.fcs(mol_feats)
        energy = self.fc_out(mol_feats)

        if not self.regress_forces:
            return energy

        # Forces are the negative gradient of the energy w.r.t. positions.
        energy_grad = torch.autograd.grad(
            energy,
            pos,
            grad_outputs=torch.ones_like(energy),
            create_graph=True,
        )[0]
        forces = -1 * (energy_grad)
        return energy, forces
Esempio n. 4
0
    def forward(self, data):
        """Predict the energy (and optionally forces) for a batch.

        With ``self.use_pbc`` the interaction stack is run here on edges
        from ``get_pbc_distances``; otherwise the call is delegated to the
        parent class's ``forward(z, pos, batch)``.

        Args:
            data: Batch object providing ``atomic_numbers``, ``pos`` and
                ``batch``; ``cell``, ``edge_index``, ``cell_offsets`` and
                ``neighbors`` are read when ``self.use_pbc`` is set and
                the last three are overwritten when ``self.otf_graph`` is
                set.

        Returns:
            ``energy`` alone, or ``(energy, forces)`` when
            ``self.regress_forces`` is true.
        """
        z = data.atomic_numbers.long()
        pos = data.pos
        if self.regress_forces:
            pos = pos.requires_grad_(True)
        batch = data.batch

        if self.otf_graph:
            (
                data.edge_index,
                data.cell_offsets,
                data.neighbors,
            ) = radius_graph_pbc(data, self.cutoff, 50, data.pos.device)

        # TODO return distance computation in radius_graph_pbc to remove need
        # for get_pbc_distances call
        if not self.use_pbc:
            energy = super(SchNetWrap, self).forward(z, pos, batch)
        else:
            assert z.dim() == 1 and z.dtype == torch.long

            pbc = get_pbc_distances(
                pos,
                data.edge_index,
                data.cell,
                data.cell_offsets,
                data.neighbors,
            )
            edge_index = pbc["edge_index"]
            edge_weight = pbc["distances"]
            edge_attr = self.distance_expansion(edge_weight)

            # Residual updates of the node embedding.
            h = self.embedding(z)
            for block in self.interactions:
                h = h + block(h, edge_index, edge_weight, edge_attr)

            h = self.lin2(self.act(self.lin1(h)))

            # Without a batch vector every atom is assigned to index 0.
            if batch is None:
                batch = torch.zeros_like(z)
            energy = scatter(h, batch, dim=0, reduce=self.readout)

        if not self.regress_forces:
            return energy

        # Forces are the negative gradient of the energy w.r.t. positions.
        energy_grad = torch.autograd.grad(
            energy,
            pos,
            grad_outputs=torch.ones_like(energy),
            create_graph=True,
        )[0]
        forces = -1 * (energy_grad)
        return energy, forces
Esempio n. 5
0
    def calculate(self, atoms, properties, system_changes):
        """Run the trainer on ``atoms`` and store energy and forces.

        After delegating to ``Calculator.calculate``, ``atoms`` is converted
        with ``self.a2g`` and collated into a single-item batch. When
        ``self.pbc_graph`` is set, the batch's ``edge_index``,
        ``cell_offsets`` and ``neighbors`` are replaced by the output of
        ``radius_graph_pbc``. The first entry of the ``"energy"`` and
        ``"forces"`` predictions is written to ``self.results``.
        """
        Calculator.calculate(self, atoms, properties, system_changes)
        batch = data_list_collater([self.a2g.convert(atoms)])
        if self.pbc_graph:
            (
                batch.edge_index,
                batch.cell_offsets,
                batch.neighbors,
            ) = radius_graph_pbc(batch, 6, 50, batch.pos.device)
        predictions = self.trainer.predict(batch)

        for key in ("energy", "forces"):
            self.results[key] = predictions[key][0]
Esempio n. 6
0
    def calculate(self, atoms, properties, system_changes):
        """Run the trainer on ``atoms`` and store the task's results.

        After delegating to ``Calculator.calculate``, ``atoms`` is converted
        with ``self.a2g`` and collated into a single-item batch. When
        ``self.pbc_graph`` is set, the batch's ``edge_index``,
        ``cell_offsets`` and ``neighbors`` are replaced by the output of
        ``radius_graph_pbc``. For trainer name ``"s2ef"`` both energy and
        forces are stored in ``self.results``; for ``"is2re"`` only the
        energy is stored; any other name stores nothing.
        """
        Calculator.calculate(self, atoms, properties, system_changes)
        batch = data_list_collater([self.a2g.convert(atoms)])
        if self.pbc_graph:
            (
                batch.edge_index,
                batch.cell_offsets,
                batch.neighbors,
            ) = radius_graph_pbc(batch, 6, 50, batch.pos.device)

        predictions = self.trainer.predict(batch, per_image=False)
        task = self.trainer.name
        if task == "s2ef" or task == "is2re":
            self.results["energy"] = predictions["energy"].item()
        if task == "s2ef":
            self.results["forces"] = predictions["forces"].cpu().numpy()
Esempio n. 7
0
    def _forward(self, data):
        """Predict one energy per graph from a periodic batch.

        Only the periodic path is implemented: with ``self.use_pbc`` false
        the method raises immediately instead of first building a radius
        graph that would be thrown away.

        Args:
            data: Batch object providing ``pos``, ``batch``,
                ``atomic_numbers``, ``cell``, ``edge_index``,
                ``cell_offsets`` and ``neighbors``. The last three are
                overwritten when ``self.otf_graph`` is set.

        Returns:
            The output of ``self.head_post_pool`` applied to the per-graph
            mean of the node features.

        Raises:
            NotImplementedError: If ``self.use_pbc`` is false.
        """
        pos = data.pos
        batch = data.batch

        if self.otf_graph:
            edge_index, cell_offsets, neighbors = radius_graph_pbc(
                data, self.cutoff, 50, data.pos.device)
            data.edge_index = edge_index
            data.cell_offsets = cell_offsets
            data.neighbors = neighbors

        if not self.use_pbc:
            # The layers below need per-edge cell offsets, which only the
            # periodic path produces.
            raise NotImplementedError(
                "non-periodic graphs are not supported; "
                "self.use_pbc must be True")

        out = get_pbc_distances(
            pos,
            data.edge_index,
            data.cell,
            data.cell_offsets,
            data.neighbors,
            return_offsets=True,
        )
        edge_index = out["edge_index"]
        cell_offsets = out["offsets"]

        h = self.atom_map[data.atomic_numbers.long()]
        # Propagate messages along edges and average over energies.
        h = self.embedding_mlp(h)
        for i in range(self.hidden_layer):
            # Each layer returns updated node features and positions.
            h, pos = self.egnn[i](h, pos, edge_index, cell_offsets, batch)

        out = self.head_pre_pool(h)
        out = global_mean_pool(out, batch)
        energy = self.head_post_pool(out)

        return energy
Esempio n. 8
0
    def _forward(self, data):
        """Predict one energy per graph from a periodic batch.

        Only the periodic path is implemented: with ``self.use_pbc`` false
        the method raises immediately instead of first building a radius
        graph that would be thrown away.

        Args:
            data: Batch object providing ``pos``, ``batch``,
                ``atomic_numbers``, ``cell``, ``edge_index``,
                ``cell_offsets`` and ``neighbors``. The last three are
                overwritten when ``self.otf_graph`` is set.

        Returns:
            The output of ``self.head_post_pool`` applied to the per-graph
            mean of the node features.

        Raises:
            NotImplementedError: If ``self.use_pbc`` is false.
        """
        pos = data.pos
        batch = data.batch

        if self.otf_graph:
            edge_index, cell_offsets, neighbors = radius_graph_pbc(
                data, self.cutoff, 50, data.pos.device
            )
            data.edge_index = edge_index
            data.cell_offsets = cell_offsets
            data.neighbors = neighbors

        if not self.use_pbc:
            # The layers below need per-edge cell offsets, which only the
            # periodic path produces.
            raise NotImplementedError(
                "non-periodic graphs are not supported; "
                "self.use_pbc must be True"
            )

        out = get_pbc_distances(
            pos,
            data.edge_index,
            data.cell,
            data.cell_offsets,
            data.neighbors,
            return_offsets=True,
        )
        edge_index = out["edge_index"]
        cell_offsets = out["offsets"]

        h = self.atom_map[data.atomic_numbers.long()]
        h = self.embedding(h)
        for layer in self.layers:
            # Each layer returns updated node features and positions.
            h, pos = layer(h, pos, edge_index, cell_offsets)

        # Output heads
        h = self.head_pre_pool(h)
        h = global_mean_pool(h, batch)
        energy = self.head_post_pool(h)

        return energy
Esempio n. 9
0
File: gemnet.py Progetto: wood-b/ocp
    def generate_interaction_graph(self, data):
        """Build the edge and triplet index structures for a batch.

        The edge list comes either from ``radius_graph_pbc`` (when
        ``self.otf_graph`` is set) or from the graph already stored on
        ``data``. It is then passed through ``get_pbc_distances``,
        ``self.select_edges`` and ``self.reorder_symmetric_edges`` before
        the swap indices and triplets are derived from it.

        Args:
            data: Batch object providing ``atomic_numbers``, ``pos`` and
                ``cell``; ``edge_index``, ``cell_offsets`` and
                ``neighbors`` are read when ``self.otf_graph`` is false.

        Returns:
            Tuple ``(edge_index, neighbors, D_st, V_st, id_swap, id3_ba,
            id3_ca, id3_ragged_idx)``: the processed edges, their distances
            ``D_st`` and direction vectors ``V_st``, the output of
            ``repeat_blocks`` and the three outputs of
            ``self.get_triplets``.
        """
        num_atoms = data.atomic_numbers.size(0)

        if self.otf_graph:
            edge_index, cell_offsets, neighbors = radius_graph_pbc(
                data, self.cutoff, self.max_neighbors)
        else:
            # Reuse the graph stored on the batch.
            edge_index = data.edge_index
            cell_offsets = data.cell_offsets
            neighbors = data.neighbors

        # Switch the indices, so the second one becomes the target index,
        # over which we can efficiently aggregate.
        out = get_pbc_distances(
            data.pos,
            edge_index,
            data.cell,
            cell_offsets,
            neighbors,
            return_offsets=True,
            return_distance_vec=True,
        )

        edge_index = out["edge_index"]
        D_st = out["distances"]
        # These vectors actually point in the opposite direction.
        # But we want to use col as idx_t for efficient aggregation.
        # Dividing by the distance normalises each vector to unit length.
        V_st = -out["distance_vec"] / D_st[:, None]
        # offsets_ca = -out["offsets"]  # a - c + offset

        # Mask interaction edges if required
        # NOTE(review): the cutoff is skipped when the graph was just built
        # with self.cutoff, or when self.cutoff is close to 6 -- presumably
        # the cutoff used for the stored graphs; confirm.
        if self.otf_graph or np.isclose(self.cutoff, 6):
            select_cutoff = None
        else:
            select_cutoff = self.cutoff
        (
            edge_index,
            cell_offsets,
            neighbors,
            D_st,
            V_st,
        ) = self.select_edges(
            data=data,
            edge_index=edge_index,
            cell_offsets=cell_offsets,
            neighbors=neighbors,
            edge_dist=D_st,
            edge_vector=V_st,
            cutoff=select_cutoff,
        )

        (
            edge_index,
            cell_offsets,
            neighbors,
            D_st,
            V_st,
        ) = self.reorder_symmetric_edges(edge_index, cell_offsets, neighbors,
                                         D_st, V_st)

        # Indices for swapping c->a and a->c (for symmetric MP)
        # NOTE(review): halving assumes every entry of ``neighbors`` is even
        # after reordering, i.e. each edge has a counterpart -- confirm.
        block_sizes = neighbors // 2
        id_swap = repeat_blocks(
            block_sizes,
            repeats=2,
            continuous_indexing=False,
            start_idx=block_sizes[0],
            block_inc=block_sizes[:-1] + block_sizes[1:],
            repeat_inc=-block_sizes,
        )

        id3_ba, id3_ca, id3_ragged_idx = self.get_triplets(edge_index,
                                                           num_atoms=num_atoms)

        return (
            edge_index,
            neighbors,
            D_st,
            V_st,
            id_swap,
            id3_ba,
            id3_ca,
            id3_ragged_idx,
        )
Esempio n. 10
0
    def _forward(self, data):
        """Compute per-atom outputs and their per-graph sum.

        Edges come from ``get_pbc_distances`` (periodic) or
        ``radius_graph`` (non-periodic). Triplets from ``self.triplets``
        provide the angles that feed the spherical basis.

        Args:
            data: Batch object providing ``pos``, ``batch`` and
                ``atomic_numbers``; ``cell``, ``edge_index``,
                ``cell_offsets`` and ``neighbors`` are read when
                ``self.use_pbc`` is set and the last three are overwritten
                when ``self.otf_graph`` is set.

        Returns:
            Tuple ``(P, energy)``: the accumulated output-block result and
            its sum over atoms (grouped by ``data.batch`` when present).
        """
        pos = data.pos
        batch = data.batch

        if self.otf_graph:
            edge_index, cell_offsets, neighbors = radius_graph_pbc(
                data, self.cutoff, 50, data.pos.device
            )
            data.edge_index = edge_index
            data.cell_offsets = cell_offsets
            data.neighbors = neighbors

        if self.use_pbc:
            out = get_pbc_distances(
                pos,
                data.edge_index,
                data.cell,
                data.cell_offsets,
                data.neighbors,
                return_offsets=True,
            )

            edge_index = out["edge_index"]
            dist = out["distances"]
            offsets = out["offsets"]

            j, i = edge_index
        else:
            edge_index = radius_graph(pos, r=self.cutoff, batch=batch)
            j, i = edge_index
            dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt()

        _, _, idx_i, idx_j, idx_k, idx_kj, idx_ji = self.triplets(
            edge_index, num_nodes=data.atomic_numbers.size(0)
        )

        # Calculate angles. Positions are detached, so the angles do not
        # contribute gradients with respect to the positions.
        pos_i = pos[idx_i].detach()
        pos_j = pos[idx_j].detach()
        pos_k = pos[idx_k].detach()
        if self.use_pbc:
            pos_ji = pos_j - pos_i + offsets[idx_ji]
            pos_kj = pos_k - pos_j + offsets[idx_kj]
        else:
            pos_ji = pos_j - pos_i
            pos_kj = pos_k - pos_j

        a = (pos_ji * pos_kj).sum(dim=-1)
        # ``dim`` must be explicit: without it torch.cross uses the first
        # dimension of size 3, which is the triplet axis when there happen
        # to be exactly three triplets.
        b = torch.cross(pos_ji, pos_kj, dim=-1).norm(dim=-1)
        angle = torch.atan2(b, a)

        rbf = self.rbf(dist)
        sbf = self.sbf(dist, angle, idx_kj)

        # Embedding block.
        x = self.emb(data.atomic_numbers.long(), rbf, i, j)
        P = self.output_blocks[0](x, rbf, i, num_nodes=pos.size(0))

        # Interaction blocks.
        for interaction_block, output_block in zip(
            self.interaction_blocks, self.output_blocks[1:]
        ):
            x = interaction_block(x, rbf, sbf, idx_kj, idx_ji)
            P += output_block(x, rbf, i, num_nodes=pos.size(0))

        energy = (
            P.sum(dim=0)
            if data.batch is None
            else scatter(P, data.batch, dim=0)
        )

        return P, energy
Esempio n. 11
0
    def _forward(self, data):
        """Predict one output per graph from a periodic batch.

        Only the periodic path is implemented: with ``self.use_pbc`` false
        the method raises immediately instead of first building a radius
        graph that would be thrown away.

        Args:
            data: Batch object providing ``pos``, ``batch``,
                ``atomic_numbers``, ``num_nodes``, ``cell``,
                ``edge_index``, ``cell_offsets`` and ``neighbors``. The
                last three are overwritten when ``self.otf_graph`` is set.

        Returns:
            The output of ``self.head_post_pool_layer_2`` after per-graph
            mean pooling of the node features.

        Raises:
            NotImplementedError: If ``self.use_pbc`` is false.
        """
        pos = data.pos
        batch = data.batch

        if self.otf_graph:
            edge_index, cell_offsets, neighbors = radius_graph_pbc(
                data, self.cutoff, 50, data.pos.device
            )
            data.edge_index = edge_index
            data.cell_offsets = cell_offsets
            data.neighbors = neighbors

        if not self.use_pbc:
            # The edge attributes below need per-edge cell offsets, which
            # only the periodic path produces.
            raise NotImplementedError(
                "non-periodic graphs are not supported; "
                "self.use_pbc must be True"
            )

        out = get_pbc_distances(
            pos,
            data.edge_index,
            data.cell,
            data.cell_offsets,
            data.neighbors,
            return_offsets=True,
        )
        edge_index = out["edge_index"]
        cell_offsets = out["offsets"]

        # construct the node and edge attributes
        z = data.atomic_numbers.long()
        src, dst = edge_index[0], edge_index[1]
        rel_pos = (pos[src] - pos[dst]) + cell_offsets
        # Squared length of rel_pos (no square root is taken), kept as a
        # trailing singleton dimension.
        # NOTE(review): the radii are subtracted from this squared length
        # directly -- confirm that is intended.
        edge_dist = rel_pos.pow(2).sum(-1, keepdims=True)

        radii_src = self.atom_radii[z[src]][:, None]
        radii_dst = self.atom_radii[z[dst]][:, None]
        edge_dist_radii_1 = edge_dist - radii_src
        edge_dist_radii_2 = edge_dist - radii_dst
        edge_dist_radii_12 = edge_dist - radii_src - radii_dst

        edge_attr = spherical_harmonics(
            self.attr_irreps, rel_pos, normalize=True,
            normalization='component',
        )
        node_attr = self.node_attribute_net(edge_index, edge_attr)
        if data.contains_isolated_nodes():
            num_covered = edge_index.max().item() + 1
            if num_covered != data.num_nodes:
                # Nodes with an index above the largest one in edge_index
                # got no attribute row; append rows of the form
                # [1, 0, ..., 0] for them.
                nr_add_attr = data.num_nodes - num_covered
                first_basis = np.eye(node_attr.shape[-1])[0, :]
                add_attr = node_attr.new_tensor(
                    np.tile(first_basis, (nr_add_attr, 1))
                )
                node_attr = torch.cat((node_attr, add_attr), -2)

        x = self.atom_map[z]
        x = self.embedding_layer_1(x, node_attr)
        x = self.embedding_layer_2(x, node_attr)
        x = self.embedding_layer_3(x, node_attr)

        # The main layers; each returns updated features and positions.
        for layer in self.layers:
            x, pos = layer(
                x,
                pos,
                edge_index,
                edge_dist,
                edge_dist_radii_1,
                edge_dist_radii_2,
                edge_dist_radii_12,
                edge_attr,
                node_attr,
            )

        # Output head
        x = self.head_pre_pool_layer_1(x, node_attr)
        x = self.head_pre_pool_layer_2(x)
        x = global_mean_pool(x, batch)
        x = self.head_post_pool_layer_1(x)
        x = self.head_post_pool_layer_2(x)

        # Return the result
        return x
Esempio n. 12
0
    def forward(self, data):
        """Predict per-graph energies and per-atom forces for a batch.

        Edges always come from ``get_pbc_distances``. Edge attributes are
        built from normalised edge vectors and four radius-adjusted
        distances, expanded with ``self.basis_fun`` and passed through the
        interaction blocks as residual updates.

        Args:
            data: Batch object providing ``atomic_numbers``, ``pos``,
                ``batch``, ``cell``, ``edge_index``, ``cell_offsets`` and
                ``neighbors``. The last three are overwritten when
                ``self.otf_graph`` is set.

        Returns:
            Tuple ``(energy, force)`` from ``self.energy_mlp`` (on the
            per-graph sum of node features) and ``self.decoder`` (on the
            node features).

        Raises:
            RuntimeError: If ``self.feat`` is neither ``"simple"`` nor
                ``"full"``.
        """
        z = data.atomic_numbers.long()

        pos = data.pos
        batch = data.batch

        if self.feat not in ("simple", "full"):
            raise RuntimeError("Undefined feature type for atom")
        h = self.embedding(z if self.feat == "simple" else self.atom_map[z])

        if self.otf_graph:
            (
                data.edge_index,
                data.cell_offsets,
                data.neighbors,
            ) = radius_graph_pbc(data, self.cutoff, 50)

        pbc = get_pbc_distances(
            pos,
            data.edge_index,
            data.cell,
            data.cell_offsets,
            data.neighbors,
            return_distance_vec=True,
        )
        edge_index = pbc["edge_index"]
        edge_dist = pbc["distances"]
        edge_vec = pbc["distance_vec"]

        # NOTE(review): edge_attr_sph is only bound when
        # self.pbc_apply_sph_harm is set, yet it is read further down
        # whenever "sph" is in self.basis_type -- confirm the two settings
        # always go together.
        if self.pbc_apply_sph_harm:
            edge_attr_sph = self.pbc_sph(edge_vec / edge_dist.view(-1, 1))

        # calculate the edge weight according to the dist
        edge_weight = torch.cos(0.5 * edge_dist * PI / self.cutoff)

        # normalized edge vectors
        edge_vec_normalized = edge_vec / edge_dist.view(-1, 1)

        # Distance variants with the radii of the two endpoint atoms
        # subtracted, stacked as columns and scaled by the cutoff.
        radii_0 = self.atom_radii[z[edge_index[0]]]
        radii_1 = self.atom_radii[z[edge_index[1]]]
        dist_variants = [
            edge_dist,
            edge_dist - radii_0,
            edge_dist - radii_1,
            edge_dist - radii_0 - radii_1,
        ]
        edge_dist_list = torch.stack(dist_variants).transpose(0, 1) / self.cutoff

        if self.ablation == "nodistlist":
            # Keep only the plain-distance column.
            edge_dist_list = edge_dist_list[:, 0].view(-1, 1)

        # make sure distance is positive
        edge_dist_list[edge_dist_list < 1e-3] = 1e-3

        # squash to [0,1] for gaussian basis
        if self.basis_type == "gauss":
            edge_vec_normalized = (edge_vec_normalized + 1) / 2.0

        # process raw_edge_attributes to generate edge_attributes
        if self.ablation != "onlydist":
            raw_edge_attr = torch.cat(
                [edge_vec_normalized, edge_dist_list], dim=1
            )
        else:
            raw_edge_attr = edge_dist_list

        if "sph" in self.basis_type:
            edge_attr = self.basis_fun(raw_edge_attr, edge_attr_sph)
        else:
            edge_attr = self.basis_fun(raw_edge_attr)

        # pass edge_attributes through interaction blocks
        for interaction in self.interactions:
            h = h + interaction(h, edge_index, edge_attr, edge_weight)

        h = self.activation(self.lin(h))

        pooled = scatter(h, batch, dim=0, reduce="add")

        force = self.decoder(h)
        energy = self.energy_mlp(pooled)
        return energy, force
Esempio n. 13
0
    def forward(self, data):
        """Predict the energy (and optionally forces) for a batch.

        Edges come from ``get_pbc_distances`` (periodic) or
        ``radius_graph`` (non-periodic). Triplets from ``self.triplets``
        provide the angles that feed the spherical basis; in training mode
        a random subset of the triplets is kept.

        Args:
            data: Batch object providing ``pos``, ``batch``, ``natoms`` and
                ``atomic_numbers``; ``cell``, ``edge_index``,
                ``cell_offsets`` and ``neighbors`` are read when
                ``self.use_pbc`` is set and the last three are overwritten
                when ``self.otf_graph`` is set.

        Returns:
            ``energy`` alone, or ``(energy, forces)`` when
            ``self.regress_forces`` is true.
        """
        pos = data.pos
        if self.regress_forces:
            pos = pos.requires_grad_(True)
        batch = data.batch

        if self.otf_graph:
            edge_index, cell_offsets, neighbors = radius_graph_pbc(
                data, self.cutoff, 50, data.pos.device)
            data.edge_index = edge_index
            data.cell_offsets = cell_offsets
            data.neighbors = neighbors

        if self.use_pbc:
            out = get_pbc_distances(
                pos,
                data.edge_index,
                data.cell,
                data.cell_offsets,
                data.neighbors,
                return_offsets=True,
            )

            edge_index = out["edge_index"]
            dist = out["distances"]
            offsets = out["offsets"]

            j, i = edge_index
        else:
            edge_index = radius_graph(pos, r=self.cutoff, batch=batch)
            j, i = edge_index
            dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt()

        _, _, idx_i, idx_j, idx_k, idx_kj, idx_ji = self.triplets(
            edge_index, num_nodes=data.atomic_numbers.size(0))

        # Cap no. of triplets during training: keep a random subset of at
        # most max_angles_per_image triplets per entry of data.natoms.
        if self.training:
            max_triplets = self.max_angles_per_image * data.natoms.size(0)
            sub_ix = torch.randperm(idx_i.size(0))[:max_triplets]
            idx_i, idx_j, idx_k = idx_i[sub_ix], idx_j[sub_ix], idx_k[sub_ix]
            idx_kj, idx_ji = idx_kj[sub_ix], idx_ji[sub_ix]

        # Calculate angles. Positions are detached, so the angles do not
        # contribute gradients with respect to the positions.
        pos_i = pos[idx_i].detach()
        pos_j = pos[idx_j].detach()
        pos_k = pos[idx_k].detach()
        if self.use_pbc:
            pos_ji = pos_j - pos_i + offsets[idx_ji]
            pos_kj = pos_k - pos_j + offsets[idx_kj]
        else:
            pos_ji = pos_j - pos_i
            pos_kj = pos_k - pos_j

        a = (pos_ji * pos_kj).sum(dim=-1)
        # ``dim`` must be explicit: without it torch.cross uses the first
        # dimension of size 3, which is the triplet axis when there happen
        # to be exactly three triplets.
        b = torch.cross(pos_ji, pos_kj, dim=-1).norm(dim=-1)
        angle = torch.atan2(b, a)

        rbf = self.rbf(dist)
        sbf = self.sbf(dist, angle, idx_kj)

        # Embedding block.
        x = self.emb(data.atomic_numbers.long(), rbf, i, j)
        P = self.output_blocks[0](x, rbf, i, num_nodes=pos.size(0))

        # Interaction blocks.
        for interaction_block, output_block in zip(self.interaction_blocks,
                                                   self.output_blocks[1:]):
            x = interaction_block(x, rbf, sbf, idx_kj, idx_ji)
            P += output_block(x, rbf, i, num_nodes=pos.size(0))

        energy = P.sum(dim=0) if batch is None else scatter(P, batch, dim=0)

        if self.regress_forces:
            # Forces are the negative gradient of the energy w.r.t.
            # positions.
            forces = -1 * (torch.autograd.grad(
                energy,
                pos,
                grad_outputs=torch.ones_like(energy),
                create_graph=True,
            )[0])
            return energy, forces
        else:
            return energy