Пример #1
0
    def __call__(self, data_list):
        """Collate a list of graph samples into one batch per device.

        With zero or one GPU (the CPU-only case included) the whole list is
        collated into a single batch. Otherwise the samples are partitioned
        into contiguous chunks of roughly equal total node count, one chunk
        per device, and each chunk is collated separately.

        Args:
            data_list: sequence of data objects; each must expose
                ``num_nodes``.

        Returns:
            list: the collated batches, one per non-empty chunk.
        """
        if self.num_gpus in [0, 1]:  # adds cpu-only case
            batch = data_list_collater(data_list, otf_graph=self.otf_graph)
            return [batch]

        num_devices = min(self.num_gpus, len(data_list))

        # Place each sample on a device according to where the midpoint of
        # its node range falls on [0, num_devices). Every device therefore
        # receives a contiguous slice with roughly the same number of nodes.
        count = torch.tensor([data.num_nodes for data in data_list])
        cumsum = count.cumsum(0)
        cumsum = torch.cat([cumsum.new_zeros(1), cumsum], dim=0)
        device_id = num_devices * cumsum.to(torch.float) / cumsum[-1].item()
        device_id = (device_id[:-1] + device_id[1:]) / 2.0
        device_id = device_id.to(torch.long)

        # Turn per-sample device ids into slice boundaries. unique() drops
        # repeated boundaries, so no empty chunk is produced.
        split = device_id.bincount().cumsum(0)
        split = torch.cat([split.new_zeros(1), split], dim=0)
        split = torch.unique(split, sorted=True).tolist()

        # Forward otf_graph here as well, so multi-device batches are built
        # the same way as the single-device batch above.
        return [
            data_list_collater(data_list[start:end], otf_graph=self.otf_graph)
            for start, end in zip(split[:-1], split[1:])
        ]
Пример #2
0
    def test_energy_force_shape(self):
        """Check the model's output shapes for a single-structure batch.

        The energy must be a (1, 1) tensor, and the forces must hold three
        components for every row of the sample's ``pos``.
        """
        sample = self.data

        # Run the model on a batch that contains only this one sample.
        prediction = self.model(data_list_collater([sample]))

        predicted_energy = prediction[0].detach()
        np.testing.assert_equal(predicted_energy.shape, (1, 1))

        predicted_forces = prediction[1].detach()
        expected_force_shape = (sample.pos.shape[0], 3)
        np.testing.assert_equal(predicted_forces.shape, expected_force_shape)
Пример #3
0
    def test_pbc_distances(self):
        """Recomputed PBC edges and distances must match those on the batch.

        Five copies of the sample are batched. ``get_pbc_distances`` is then
        run on the batch's own geometry, and its output is compared against
        the batch's stored ``edge_index`` and ``distances``.
        """
        batch = data_list_collater([self.data] * 5)

        result = get_pbc_distances(
            batch.pos,
            batch.edge_index,
            batch.cell,
            batch.cell_offsets,
            batch.neighbors,
        )
        computed_edges = result["edge_index"]
        computed_distances = result["distances"]

        np.testing.assert_array_equal(batch.edge_index, computed_edges)
        np.testing.assert_array_almost_equal(
            batch.distances, computed_distances
        )
Пример #4
0
def oc20_get_energy_and_forces(cell, atomic_numbers, positions):
    """
    Predict total energy and atomic forces w/ pre-trained GNNP of OC20 (i.e. S2EF).

    A module-level ``myAtoms`` object is reused between calls. It is rebuilt
    on the first call and whenever the number of atoms changes, because
    ASE's ``set_atomic_numbers``/``set_positions`` require arrays whose
    length matches the existing object.

    Args:
        cell: lattice vectors in angstroms.
        atomic_numbers: atomic numbers for all atoms.
        positions: xyz coordinates for all atoms in angstroms.
    Returns:
        energy:  total energy.
        forces: atomic forces.
    """

    # Initialize (or refresh) the cached Atoms object.
    global myAtoms

    if myAtoms is None or len(myAtoms) != len(atomic_numbers):
        myAtoms = Atoms(numbers=atomic_numbers,
                        positions=positions,
                        cell=cell,
                        pbc=[True, True, True])

    else:
        myAtoms.set_cell(cell)
        myAtoms.set_atomic_numbers(atomic_numbers)
        myAtoms.set_positions(positions)

    # Preprocess atomic positions (edges are built on the fly).
    # myA2G and myTrainer are module-level objects that are only read here,
    # so no `global` declaration is needed for them.
    data = myA2G.convert(myAtoms)
    batch = data_list_collater([data], otf_graph=True)

    # Predict energy and forces.
    predictions = myTrainer.predict(batch, per_image=False, disable_tqdm=True)

    energy = predictions["energy"].item()
    forces = predictions["forces"].cpu().numpy().tolist()

    return energy, forces
Пример #5
0
    def test_rotation_invariance(self):
        """Energy is rotation-invariant and forces rotate with the structure.

        A randomly rotated copy of the sample is batched together with the
        original. Both must predict the same energy. The rotated copy's
        forces, mapped back through the inverse rotation, must match the
        original's forces.
        """
        original = self.data

        # Random rotation drawn from [-180, 180] about all three axes.
        rotate = RandomRotate([-180, 180], [0, 1, 2])
        rotated, _, inverse_rotation = rotate(original.clone())
        assert not np.array_equal(original.pos, rotated.pos)

        # The first half of the batch is the original, the second half the
        # rotated copy.
        outputs = self.model(data_list_collater([original, rotated]))

        energies = outputs[0].detach()
        np.testing.assert_almost_equal(energies[0], energies[1], decimal=5)

        forces = outputs[1].detach()
        half = forces.shape[0] // 2
        original_forces = forces[:half]
        restored_forces = torch.matmul(forces[half:], inverse_rotation)
        np.testing.assert_array_almost_equal(
            original_forces,
            restored_forces,
            decimal=5,
        )