def __call__(self, data_list):
    """Collate a list of graphs into one batch per device.

    With zero or one GPU (which includes the CPU-only case) the whole list is
    collated into a single batch. Otherwise the list is split into contiguous
    chunks so that cumulative node counts are spread roughly evenly over
    ``min(self.num_gpus, len(data_list))`` devices. Each chunk is collated
    separately.

    Args:
        data_list: sequence of graph objects exposing ``num_nodes``.

    Returns:
        A list of collated batches, one per non-empty chunk.
    """
    if self.num_gpus in [0, 1]:  # adds cpu-only case
        batch = data_list_collater(data_list, otf_graph=self.otf_graph)
        return [batch]

    num_devices = min(self.num_gpus, len(data_list))

    # Cumulative node counts with a leading zero: cumsum[k] is the number of
    # nodes contained in data_list[:k].
    count = torch.tensor([data.num_nodes for data in data_list])
    cumsum = count.cumsum(0)
    cumsum = torch.cat([cumsum.new_zeros(1), cumsum], dim=0)

    # Scale node offsets onto [0, num_devices] and take each graph's midpoint.
    # Truncating the midpoint gives the device index for that graph.
    device_id = num_devices * cumsum.to(torch.float) / cumsum[-1].item()
    device_id = (device_id[:-1] + device_id[1:]) / 2.0
    device_id = device_id.to(torch.long)

    # Per-device graph counts -> chunk boundaries. unique() collapses repeated
    # boundaries, so empty chunks are dropped.
    split = device_id.bincount().cumsum(0)
    split = torch.cat([split.new_zeros(1), split], dim=0)
    split = torch.unique(split, sorted=True)
    split = split.tolist()

    # Forward otf_graph here as well. Previously it was only honored on the
    # single-device path above.
    return [
        data_list_collater(data_list[start:end], otf_graph=self.otf_graph)
        for start, end in zip(split[:-1], split[1:])
    ]
def test_energy_force_shape(self):
    """The model returns energies shaped (1, 1) and forces shaped (n_atoms, 3)."""
    sample = self.data

    # Single-structure batch through the model.
    predicted = self.model(data_list_collater([sample]))

    pred_energy = predicted[0].detach()
    np.testing.assert_equal(pred_energy.shape, (1, 1))

    pred_forces = predicted[1].detach()
    np.testing.assert_equal(pred_forces.shape, (sample.pos.shape[0], 3))
def test_pbc_distances(self):
    """PBC distances recomputed on a 5-copy batch match the stored values."""
    sample = self.data
    batch = data_list_collater([sample] * 5)

    result = get_pbc_distances(
        batch.pos,
        batch.edge_index,
        batch.cell,
        batch.cell_offsets,
        batch.neighbors,
    )
    returned_edges, computed_distances = result["edge_index"], result["distances"]

    # Edges must come back unchanged; distances must agree with the stored ones.
    np.testing.assert_array_equal(
        batch.edge_index,
        returned_edges,
    )
    np.testing.assert_array_almost_equal(batch.distances, computed_distances)
def oc20_get_energy_and_forces(cell, atomic_numbers, positions):
    """
    Predict total energy and atomic forces w/ pre-trained GNNP of OC20 (i.e. S2EF).

    A module-level ``Atoms`` object (``myAtoms``) is reused across calls. It is
    rebuilt whenever it does not exist yet or the number of atoms has changed.

    Args:
        cell: lattice vectors in angstroms.
        atomic_numbers: atomic numbers for all atoms.
        positions: xyz coordinates for all atoms in angstroms.

    Returns:
        energy: total energy.
        forces: atomic forces.
    """
    # Initialize Atoms. ASE's per-atom setters reject arrays whose shape
    # differs from the stored ones, so a change in atom count needs a new object.
    global myAtoms
    if myAtoms is None or len(myAtoms) != len(atomic_numbers):
        myAtoms = Atoms(numbers=atomic_numbers, positions=positions,
                        cell=cell, pbc=[True, True, True])
    else:
        myAtoms.set_cell(cell)
        myAtoms.set_atomic_numbers(atomic_numbers)
        myAtoms.set_positions(positions)

    # Preprocessing atomic positions (the edges on-the-fly).
    # myA2G and myTrainer are module-level globals that are only read here.
    data = myA2G.convert(myAtoms)
    batch = data_list_collater([data], otf_graph=True)

    # Predicting energy and forces
    predictions = myTrainer.predict(batch, per_image=False, disable_tqdm=True)
    energy = predictions["energy"].item()
    forces = predictions["forces"].cpu().numpy().tolist()

    return energy, forces
def test_rotation_invariance(self):
    """Energies are unchanged, and forces rotate with the structure, under a random rotation."""
    original = self.data

    # Random rotation with angles in [-180, 180] about every axis.
    rotate = RandomRotate([-180, 180], [0, 1, 2])
    rotated, _rot, inv_rot = rotate(original.clone())
    assert not np.array_equal(original.pos, rotated.pos)

    # Evaluate both structures together in one batch.
    predicted = self.model(data_list_collater([original, rotated]))

    pred_energy = predicted[0].detach()
    np.testing.assert_almost_equal(pred_energy[0], pred_energy[1], decimal=5)

    pred_forces = predicted[1].detach()
    half = pred_forces.shape[0] // 2
    # Rotate the second half of the forces back before comparing with the first half.
    np.testing.assert_array_almost_equal(
        pred_forces[:half],
        torch.matmul(pred_forces[half:], inv_rot),
        decimal=5,
    )