def test_collect_atom_triples_batch(four_atoms, ase_env):
    """
    Check env.collect_atom_triples_batch against env.collect_atom_triples.

    Two neighbor lists of different size are built for the same structure
    (cutoff 1.1 and 3.0), padded with -1 to a common shape and processed one
    at a time by env.collect_atom_triples. Stacking those per-structure
    results must reproduce the output of the batched routine applied to the
    stacked neighbor lists and masks.

    Note: this test leaves ``ase_env.cutoff`` set to 3.0.
    """
    # Get the first environment (two atoms)
    ase_env.cutoff = 1.1
    nbh_1, _ = ase_env.get_environment(four_atoms)
    # Get the second environment (all atoms)
    ase_env.cutoff = 3.0
    nbh_2, _ = ase_env.get_environment(four_atoms)

    # Pad to same size (assumes -1 padding). An integer dtype keeps the padded
    # lists valid as indices; np.int was removed in NumPy 1.24.
    padded_shape = (
        max(nbh_1.shape[0], nbh_2.shape[0]),
        max(nbh_1.shape[1], nbh_2.shape[1]),
    )
    padded = []
    for nbh in (nbh_1, nbh_2):
        tmp = -np.ones(padded_shape, dtype=np.int64)
        tmp[: nbh.shape[0], : nbh.shape[1]] = nbh
        padded.append(tmp)

    # Neighbor masks and per-structure pair indices
    masks = [(nbh >= 0).astype(np.int64) for nbh in padded]
    triples = [env.collect_atom_triples(nbh) for nbh in padded]

    # A pair (j, k) is valid only if both of its neighbors are valid
    pair_masks = [
        np.take_along_axis(mask, off_j, axis=1)
        * np.take_along_axis(mask, off_k, axis=1)
        for mask, (_, _, off_j, off_k) in zip(masks, triples)
    ]

    def as_batch(arrays):
        # Stack per-structure arrays along a new leading batch dimension
        return torch.LongTensor(np.array(arrays))

    expected_nbh_j = as_batch([t[0] for t in triples])
    expected_nbh_k = as_batch([t[1] for t in triples])
    expected_offset_idx_j = as_batch([t[2] for t in triples])
    expected_offset_idx_k = as_batch([t[3] for t in triples])
    expected_mask_jk = as_batch(pair_masks)

    # Collect triples via batch method
    (
        nbh_j,
        nbh_k,
        offset_idx_j,
        offset_idx_k,
        pair_mask,
    ) = env.collect_atom_triples_batch(as_batch(padded), as_batch(masks))

    assert np.allclose(expected_nbh_j, nbh_j)
    assert np.allclose(expected_nbh_k, nbh_k)
    assert np.allclose(expected_offset_idx_j, offset_idx_j)
    assert np.allclose(expected_offset_idx_k, offset_idx_k)
    assert np.allclose(expected_mask_jk, pair_mask)
def test_collect_atom_triples(four_atoms, ase_env):
    """
    Compare env.collect_atom_triples with an explicit enumeration of pairs.

    For every neighbor slot k and every later slot j > k (ordered by k, then
    j), the expected pair neighbors are the corresponding columns of the
    neighbor list. The expected offset indices are those slot numbers,
    repeated for every atom.
    """
    nbh, _ = ase_env.get_environment(four_atoms)
    n_atoms, n_neighbors = nbh.shape

    slot_pairs = [
        (later, earlier)
        for earlier in range(n_neighbors)
        for later in range(earlier + 1, n_neighbors)
    ]
    idx_j = np.array([pair[0] for pair in slot_pairs])
    idx_k = np.array([pair[1] for pair in slot_pairs])

    expected = (
        nbh[:, idx_j],
        nbh[:, idx_k],
        np.tile(idx_j, (n_atoms, 1)),
        np.tile(idx_k, (n_atoms, 1)),
    )
    actual = env.collect_atom_triples(nbh)

    for reference, result in zip(expected, actual):
        assert np.allclose(reference, result)
def __getitem__(self, idx):
    """
    Return the properties of structure ``idx`` extended with neighbor data.

    Args:
        idx (int): Index of the structure.

    Returns:
        dict: The properties returned by ``self.get_properties(idx)``, plus
            the neighbor indices, cell offsets and ``"_idx"``. If
            ``self.collect_triples`` is truthy, it also includes the
            neighbor-pair indices and their offset indices.
    """
    at, properties = self.get_properties(idx)

    # get atom environment
    nbh_idx, offsets = self.environment_provider.get_environment(at)
    # np.int was removed in NumPy 1.24; LongTensor stores 64-bit integers.
    properties[Structure.neighbors] = torch.LongTensor(nbh_idx.astype(np.int64))
    properties[Structure.cell_offset] = torch.FloatTensor(
        offsets.astype(np.float32))
    properties["_idx"] = torch.LongTensor(np.array([idx], dtype=np.int64))

    if self.collect_triples:
        nbh_idx_j, nbh_idx_k, offset_idx_j, offset_idx_k = collect_atom_triples(
            nbh_idx)
        properties[Structure.neighbor_pairs_j] = torch.LongTensor(
            nbh_idx_j.astype(np.int64))
        properties[Structure.neighbor_pairs_k] = torch.LongTensor(
            nbh_idx_k.astype(np.int64))
        properties['offset_idx_j'] = torch.LongTensor(
            offset_idx_j.astype(np.int64))
        properties['offset_idx_k'] = torch.LongTensor(
            offset_idx_k.astype(np.int64))

    return properties
def _convert_atoms(
    atoms,
    environment_provider=SimpleEnvironmentProvider(),
    collect_triples=False,
    center_positions=False,
    output=None,
):
    """
    Helper function to convert ASE atoms object to SchNetPack input format.

    Args:
        atoms (ase.Atoms): Atoms object of molecule
        environment_provider (callable): Neighbor list provider; its
            ``get_environment(atoms)`` must return ``(neighbor indices,
            cell offsets)``.
        collect_triples (bool, optional): If True, also store the neighbor
            pair indices and their offset indices from
            ``collect_atom_triples``.
        center_positions (bool, optional): If True, subtract the center of
            mass from the positions.
        output (dict): Destination for converted atoms, if not None

    Returns:
        dict of torch.Tensor: Properties including neighbor lists,
            reformatted into SchNetPack input format.
    """
    inputs = {} if output is None else output

    # Elemental composition (np.int was removed in NumPy 1.24)
    inputs[Properties.Z] = torch.LongTensor(atoms.numbers.astype(np.int64))

    positions = atoms.positions.astype(np.float32)
    if center_positions:
        positions -= atoms.get_center_of_mass()
    inputs[Properties.R] = torch.FloatTensor(positions)

    # Cell (stored once; it was previously assigned twice with the same value)
    inputs[Properties.cell] = torch.FloatTensor(
        np.array(atoms.cell.array, dtype=np.float32))

    # get atom environment
    nbh_idx, offsets = environment_provider.get_environment(atoms)
    inputs[Properties.neighbors] = torch.LongTensor(nbh_idx.astype(np.int64))
    inputs[Properties.cell_offset] = torch.FloatTensor(offsets.astype(np.float32))

    # If requested get neighbor lists for triples
    if collect_triples:
        nbh_idx_j, nbh_idx_k, offset_idx_j, offset_idx_k = collect_atom_triples(nbh_idx)
        inputs[Properties.neighbor_pairs_j] = torch.LongTensor(nbh_idx_j.astype(np.int64))
        inputs[Properties.neighbor_pairs_k] = torch.LongTensor(nbh_idx_k.astype(np.int64))
        inputs[Properties.neighbor_offsets_j] = torch.LongTensor(
            offset_idx_j.astype(np.int64)
        )
        inputs[Properties.neighbor_offsets_k] = torch.LongTensor(
            offset_idx_k.astype(np.int64)
        )

    return inputs
def _convert_atoms(
    atoms,
    environment_provider=SimpleEnvironmentProvider(),
    collect_triples=False,
    centering_function=None,
    output=None,
):
    """
    Helper function to convert ASE atoms object to SchNetPack input format.

    Unlike a tensor-producing converter, this variant stores plain numpy
    arrays; conversion to tensors is left to the caller.

    Args:
        atoms (ase.Atoms): Atoms object of molecule
        environment_provider (callable): Neighbor list provider.
        collect_triples (bool, optional): Set to True if angular features are
            needed.
        centering_function (callable or None): Function for calculating center
            of molecule (center of mass/geometry/...). Center will be subtracted
            from positions.
        output (dict): Destination for converted atoms, if not None

    Returns:
        dict of numpy.ndarray: Properties including neighbor lists, reformatted
            into SchNetPack input format.
    """
    inputs = {} if output is None else output

    # Elemental composition (np.int was removed in NumPy 1.24)
    inputs[Properties.Z] = atoms.numbers.astype(np.int64)

    positions = atoms.positions.astype(np.float32)
    if centering_function:
        positions -= centering_function(atoms)
    inputs[Properties.R] = positions

    # get atom environment
    nbh_idx, offsets = environment_provider.get_environment(atoms)
    inputs[Properties.neighbors] = nbh_idx.astype(np.int64)

    # Get cells
    inputs[Properties.cell] = np.array(atoms.cell.array, dtype=np.float32)
    inputs[Properties.cell_offset] = offsets.astype(np.float32)

    # If requested get neighbor lists for triples
    if collect_triples:
        nbh_idx_j, nbh_idx_k, offset_idx_j, offset_idx_k = collect_atom_triples(
            nbh_idx)
        inputs[Properties.neighbor_pairs_j] = nbh_idx_j.astype(np.int64)
        inputs[Properties.neighbor_pairs_k] = nbh_idx_k.astype(np.int64)
        inputs[Properties.neighbor_offsets_j] = offset_idx_j.astype(np.int64)
        inputs[Properties.neighbor_offsets_k] = offset_idx_k.astype(np.int64)

    return inputs
def convert_atoms(self, atoms):
    """
    Convert a single ASE Atoms object into a batched SchNetPack input dict.

    Args:
        atoms (ase.Atoms): Atoms object of molecule

    Returns:
        dict of torch.Tensor: Properties including neighbor lists and masks,
            reformatted into SchNetPack input format. Every tensor gets a
            leading batch dimension of size 1 and is moved to ``self.device``.
    """
    inputs = {}
    # A single structure is converted, so its index is always 0
    idx = 0

    # Elemental composition (np.int was removed in NumPy 1.24)
    inputs[Structure.Z] = torch.LongTensor(atoms.numbers.astype(np.int64))
    inputs[Structure.atom_mask] = torch.ones_like(inputs[Structure.Z]).float()

    # Set positions
    inputs[Structure.R] = torch.FloatTensor(atoms.positions.astype(np.float32))

    # get atom environment
    # NOTE(review): other converters in this code call get_environment(atoms);
    # this one passes (idx, atoms). Confirm against the provider's signature.
    nbh_idx, offsets = self.environment_provider.get_environment(idx, atoms)

    # Padding entries (< 0) are masked out and replaced by index 0
    mask = torch.FloatTensor(nbh_idx) >= 0
    inputs[Structure.neighbor_mask] = mask.float()
    inputs[Structure.neighbors] = torch.LongTensor(
        nbh_idx.astype(np.int64)) * mask.long()

    # Get cells. atoms.cell is an ndarray in older ASE but an ase.cell.Cell in
    # newer ASE, which is not guaranteed to offer .astype; cell[:] yields an
    # ndarray in both cases.
    inputs[Structure.cell] = torch.FloatTensor(
        np.array(atoms.cell[:], dtype=np.float32))
    inputs[Structure.cell_offset] = torch.FloatTensor(offsets.astype(np.float32))

    # Set index
    inputs['_idx'] = torch.LongTensor(np.array([idx], dtype=np.int64))

    # If requested get masks and neighbor lists for neighbor pairs.
    # A truthiness check is used because "is not None" was also True for
    # collect_triples=False.
    if self.collect_triples:
        # collect_atom_triples returns (nbh_j, nbh_k, offset_j, offset_k);
        # the offset indices are not used by this converter.
        nbh_idx_j, nbh_idx_k, _, _ = collect_atom_triples(nbh_idx)
        inputs[Structure.neighbor_pairs_j] = torch.LongTensor(
            nbh_idx_j.astype(np.int64))
        inputs[Structure.neighbor_pairs_k] = torch.LongTensor(
            nbh_idx_k.astype(np.int64))
        inputs[Structure.neighbor_pairs_mask] = torch.ones_like(
            inputs[Structure.neighbor_pairs_j]).float()

    # Add batch dimension and move to CPU/GPU
    for key, value in inputs.items():
        inputs[key] = value.unsqueeze(0).to(self.device)

    return inputs
def _sharc2schnet(self, sharc_output):
    """
    Convert SHARC coordinates into a batched SchNetPack input dict.

    Side effect: ``self.molecule.positions`` is overwritten with
    ``np.array(sharc_output)`` before the conversion.

    Args:
        sharc_output: New positions provided by SHARC (anything accepted by
            ``np.array`` that fits ``self.molecule.positions``).

    Returns:
        dict of torch.Tensor: Properties including neighbor lists and masks.
            Every tensor gets a leading batch dimension of size 1 and is moved
            to ``self.device``.
    """
    # Update internal structure with new SHARC positions
    self.molecule.positions = np.array(sharc_output)

    schnet_inputs = dict()

    # Elemental composition (np.int was removed in NumPy 1.24)
    schnet_inputs[Properties.Z] = torch.LongTensor(
        self.molecule.numbers.astype(np.int64))
    schnet_inputs[Properties.atom_mask] = torch.ones_like(
        schnet_inputs[Properties.Z]).float()

    # Set positions
    positions = self.molecule.positions.astype(np.float32)
    schnet_inputs[Properties.R] = torch.FloatTensor(positions)

    # get atom environment
    nbh_idx, offsets = self.environment_provider.get_environment(self.molecule)

    # Padding entries (< 0) are masked out and replaced by index 0
    mask = torch.FloatTensor(nbh_idx) >= 0
    schnet_inputs[Properties.neighbor_mask] = mask.float()
    schnet_inputs[Properties.neighbors] = torch.LongTensor(
        nbh_idx.astype(np.int64)) * mask.long()

    # Get cells. molecule.cell is an ndarray in older ASE but an ase.cell.Cell
    # in newer ASE, which is not guaranteed to offer .astype; cell[:] yields
    # an ndarray in both cases.
    schnet_inputs[Properties.cell] = torch.FloatTensor(
        np.array(self.molecule.cell[:], dtype=np.float32))
    schnet_inputs[Properties.cell_offset] = torch.FloatTensor(
        offsets.astype(np.float32))

    # If requested get masks and neighbor lists for neighbor pairs.
    # A truthiness check is used because "is not None" was also True for
    # collect_triples=False.
    if self.collect_triples:
        # The offset indices are not used by this converter
        nbh_idx_j, nbh_idx_k, _, _ = collect_atom_triples(nbh_idx)
        schnet_inputs[Properties.neighbor_pairs_j] = torch.LongTensor(
            nbh_idx_j.astype(np.int64))
        schnet_inputs[Properties.neighbor_pairs_k] = torch.LongTensor(
            nbh_idx_k.astype(np.int64))
        schnet_inputs[Properties.neighbor_pairs_mask] = torch.ones_like(
            schnet_inputs[Properties.neighbor_pairs_j]).float()

    # Add batch dimension and move to CPU/GPU
    for key, value in schnet_inputs.items():
        schnet_inputs[key] = value.unsqueeze(0).to(self.device)

    return schnet_inputs
def _convert_atoms(
    atoms,
    environment_provider=SimpleEnvironmentProvider(),
    collect_triples=False,
    centering_function=get_center_of_mass,
    output=None,
    res_list=None,
):
    """
    Helper function to convert ASE atoms object to SchNetPack input format.

    The conversion path is chosen from the class name of
    ``environment_provider``:

    * name does not contain ``"AP"``: neighbor list and cell offsets, plus
      neighbor pairs from ``collect_atom_triples`` if requested;
    * ``APModEnvironmentProvider``: neighbor list, cell offsets, ``"ZB"``
      (for every atom, the atomic numbers of all other atoms) and, if
      requested, all ordered pairs of distinct neighbors;
    * ``APModPBCEnvironmentProvider``: as ``APModEnvironmentProvider``, but
      the cell offsets are replaced by shifts recomputed from the
      displacement vectors;
    * any other ``"AP"`` provider (e.g. ``APNetPBCEnvironmentProvider``):
      separate inter- and intra-monomer neighbor lists plus inter/intra
      neighbor pairs with a validity mask. ``collect_triples`` is ignored
      on this path.

    Args:
        atoms (ase.Atoms): Atoms object of molecule
        environment_provider (callable): Neighbor list provider.
        collect_triples (bool, optional): Set to True if angular features are
            needed.
        centering_function (callable or None): Function for calculating center
            of molecule (center of mass/geometry/...). Center will be subtracted
            from positions.
        output (dict): Destination for converted atoms, if not None
        res_list (sequence or None): For the inter/intra providers, the
            lengths of ``res_list[0]`` and ``res_list[1]`` give the number of
            atoms of monomers A and B (stored as ``"ZA"`` / ``"ZB"``). If
            None, ``output`` must already contain ``"ZA"`` and ``"ZB"``.

    Returns:
        dict of torch.Tensor: Properties including neighbor lists and masks
            reformatted into SchNetPack input format.
    """
    inputs = {} if output is None else output

    # Elemental composition (np.int/np.bool/np.float were removed in NumPy 1.24)
    cell = np.array(atoms.cell.array, dtype=np.float32)
    inputs[Properties.Z] = torch.LongTensor(atoms.numbers.astype(np.int64))
    positions = atoms.positions.astype(np.float32)
    if centering_function:
        positions -= centering_function(atoms)
    inputs[Properties.R] = torch.FloatTensor(positions)
    inputs[Properties.cell] = torch.FloatTensor(cell)

    provider_name = type(environment_provider).__name__

    if "AP" not in provider_name:
        nbh_idx, offsets = environment_provider.get_environment(atoms)
        inputs[Properties.neighbors] = torch.LongTensor(nbh_idx.astype(np.int64))
        inputs[Properties.cell_offset] = torch.FloatTensor(
            offsets.astype(np.float32))
        if collect_triples:
            _ap_store_triples(inputs, *collect_atom_triples(nbh_idx))

    elif provider_name in ("APModEnvironmentProvider",
                           "APModPBCEnvironmentProvider"):
        nbh_idx, offsets = environment_provider.get_environment(atoms)
        inputs[Properties.neighbors] = torch.LongTensor(nbh_idx.astype(np.int64))
        if provider_name == "APModPBCEnvironmentProvider":
            box_shift = _ap_minimum_image_shift(
                inputs[Properties.R], inputs[Properties.neighbors],
                inputs[Properties.cell], offsets)
            inputs[Properties.cell_offset] = torch.FloatTensor(box_shift)
        else:
            inputs[Properties.cell_offset] = torch.FloatTensor(
                offsets.astype(np.float32))
        inputs["ZB"] = torch.LongTensor(
            _ap_other_atom_numbers(atoms.numbers, nbh_idx.shape[0]))
        if collect_triples:
            _ap_store_triples(inputs, *_ap_all_ordered_pairs(nbh_idx))

    else:
        _ap_convert_inter_intra(
            inputs, atoms, environment_provider, res_list,
            periodic=(provider_name == "APNetPBCEnvironmentProvider"))

    return inputs


def _ap_store_triples(inputs, nbh_idx_j, nbh_idx_k, offset_idx_j, offset_idx_k):
    """Store neighbor-pair indices and their offset indices as LongTensors."""
    inputs[Properties.neighbor_pairs_j] = torch.LongTensor(
        nbh_idx_j.astype(np.int64))
    inputs[Properties.neighbor_pairs_k] = torch.LongTensor(
        nbh_idx_k.astype(np.int64))
    inputs[Properties.neighbor_offsets_j] = torch.LongTensor(
        offset_idx_j.astype(np.int64))
    inputs[Properties.neighbor_offsets_k] = torch.LongTensor(
        offset_idx_k.astype(np.int64))


def _ap_other_atom_numbers(numbers, n_atoms):
    """For every atom, return the atomic numbers of all other atoms.

    Returns an (n_atoms, n_atoms - 1) int64 array: row i is ``numbers`` with
    entry i removed.
    """
    tiled = np.tile(numbers.astype(np.int64)[np.newaxis], (n_atoms, 1))
    return tiled[~np.eye(n_atoms, dtype=bool)].reshape(n_atoms, n_atoms - 1)


def _ap_all_ordered_pairs(nbh_idx):
    """Return all ordered pairs of distinct neighbors and their slot indices.

    For a (n_atoms, n_nbh) neighbor list, builds every (j, k) combination of
    neighbor slots per atom. It keeps the combinations whose neighbor indices
    differ and reshapes each result to (n_atoms, n_nbh * (n_nbh - 1)). The
    reshape assumes each row has distinct neighbor indices.

    Returns:
        tuple: (nbh_idx_j, nbh_idx_k, offset_idx_j, offset_idx_k)
    """
    n_atoms, n_nbh = nbh_idx.shape
    nbh_idx_j = np.tile(nbh_idx, n_nbh)
    nbh_idx_k = np.repeat(nbh_idx, n_nbh).reshape((n_atoms, -1))
    # Keep track of periodic images through the neighbor slot index
    offset_idx = np.tile(np.arange(n_nbh), (n_atoms, 1))
    offset_idx_j = np.tile(offset_idx, n_nbh)
    offset_idx_k = np.repeat(offset_idx, n_nbh).reshape((n_atoms, -1))

    distinct = nbh_idx_j != nbh_idx_k
    n_pairs = n_nbh * (n_nbh - 1)
    return tuple(
        values[distinct].reshape(n_atoms, n_pairs)
        for values in (nbh_idx_j, nbh_idx_k, offset_idx_j, offset_idx_k)
    )


def _ap_minimum_image_shift(positions, neighbors, cell, offsets):
    """Recompute per-neighbor cell shifts as -round(displacement / box length).

    Args:
        positions (torch.Tensor): atom positions (unbatched).
        neighbors (torch.Tensor): neighbor indices (unbatched, LongTensor).
        cell (torch.Tensor): (3, 3) cell matrix.
        offsets (numpy.ndarray): cell offsets from the environment provider.

    Returns:
        torch.Tensor: shifts with the shape of the displacement vectors
            (presumably (n_atoms, n_nbh, 3); confirm against atom_distances).
    """
    offsets_t = torch.FloatTensor(offsets.astype(np.float32)).unsqueeze(0)
    # NOTE(review): confirm that atom_distances(..., return_vecs=True) returns
    # unnormalized displacement vectors here; if they are unit vectors the
    # rounded shift is meaningless.
    _, displacement = atom_distances(
        positions.unsqueeze(0), neighbors.unsqueeze(0), cell.unsqueeze(0),
        offsets_t, return_vecs=True)
    displacement = displacement.squeeze(0)
    # Only valid for orthorhombic boxes: each Cartesian component is divided by
    # the matching box length (previously every component was divided by the
    # x length only, which is correct for cubic boxes alone).
    return -torch.round(displacement / torch.diagonal(cell))


def _ap_inter_intra_pairs(nbh_idx_inter, nbh_idx_intra):
    """Pair every inter-monomer neighbor j with every intra-monomer neighbor k.

    Returns:
        tuple: (nbh_idx_j, nbh_idx_k, offset_idx_j, offset_idx_k, mask). The
            int64 ``mask`` is 0 where either index is padding (< 0) or where k
            is the central atom itself, and 1 otherwise.
    """
    n_atoms, n_inter = nbh_idx_inter.shape
    nbh_idx_k = np.tile(nbh_idx_intra, n_inter)
    nbh_idx_j = np.repeat(nbh_idx_inter, n_inter).reshape((n_atoms, -1))
    offset_idx = np.tile(np.arange(n_inter), (n_atoms, 1))
    offset_idx_k = np.tile(offset_idx, n_inter)
    offset_idx_j = np.repeat(offset_idx, n_inter).reshape((n_atoms, -1))

    mask = np.ones(nbh_idx_j.shape, dtype=np.int64)
    mask[nbh_idx_j < 0] = 0
    mask[nbh_idx_k < 0] = 0
    # Row i of nbh_idx_k belongs to central atom i; drop k == i
    mask[nbh_idx_k == np.arange(nbh_idx_k.shape[0])[:, None]] = 0
    return nbh_idx_j, nbh_idx_k, offset_idx_j, offset_idx_k, mask


def _ap_not_self_mask(nbh_idx):
    """Boolean mask that is False where row i of ``nbh_idx`` contains i."""
    return nbh_idx != np.arange(nbh_idx.shape[0])[:, None]


def _ap_convert_inter_intra(inputs, atoms, environment_provider, res_list,
                            periodic):
    """Fill ``inputs`` for providers with separate inter/intra neighbor lists.

    If ``periodic`` is True, the inter and intra cell offsets are replaced by
    shifts from ``_ap_minimum_image_shift``; otherwise the provider's offsets
    are stored directly.
    """
    if res_list is not None:
        mon_a = len(res_list[0])
        mon_b = len(res_list[1])
        inputs['ZA'] = torch.LongTensor(atoms.numbers[0:mon_a].astype(np.int64))
        inputs['ZB'] = torch.LongTensor(
            atoms.numbers[mon_a:mon_a + mon_b].astype(np.int64))
    # Raises KeyError if res_list is None and output lacked ZA/ZB
    inputs['ZA'], inputs['ZB'] = inputs['ZA'].long(), inputs['ZB'].long()

    # get atom environment
    (nbh_idx_intra, offset_intra, nbh_idx_inter,
     offsets_inter) = environment_provider.get_environment(atoms, inputs)

    # Inter neighbors: padding (< 0) is replaced by index 0 and masked out
    nbh_inter = torch.LongTensor(nbh_idx_inter.astype(np.int64))
    inter_mask = (nbh_inter >= 0).float()
    inputs[Properties.neighbor_inter] = nbh_inter * inter_mask.long()
    inputs[Properties.neighbor_inter_mask] = inter_mask

    if periodic:
        inter_shift = _ap_minimum_image_shift(
            inputs[Properties.R], inputs[Properties.neighbor_inter],
            inputs[Properties.cell], offsets_inter)
        inputs[Properties.neighbor_offset_inter] = torch.FloatTensor(inter_shift)
    else:
        inputs[Properties.neighbor_offset_inter] = torch.FloatTensor(
            offsets_inter.astype(np.float32))

    (nbh_idx_j, nbh_idx_k, offset_idx_j, offset_idx_k,
     mask_triples) = _ap_inter_intra_pairs(nbh_idx_inter, nbh_idx_intra)
    _ap_store_triples(inputs, nbh_idx_j, nbh_idx_k, offset_idx_j, offset_idx_k)

    if periodic:
        intra_shift = _ap_minimum_image_shift(
            inputs[Properties.R],
            torch.LongTensor(nbh_idx_intra.astype(np.int64)),
            inputs[Properties.cell], offset_intra)
        inputs[Properties.cell_offset_intra] = torch.FloatTensor(intra_shift)
    inputs[Properties.neighbor_pairs_mask] = torch.LongTensor(mask_triples)

    # Regular neighbor list: intra neighbors without the central atom itself
    not_self = _ap_not_self_mask(nbh_idx_intra)
    n_atoms, n_intra = nbh_idx_intra.shape
    inputs[Properties.neighbors] = torch.LongTensor(
        nbh_idx_intra[not_self].reshape(n_atoms, n_intra - 1).astype(np.int64))

    if periodic:
        intra_offsets = intra_shift[torch.from_numpy(not_self)].reshape(
            n_atoms, n_intra - 1, 3)
        inputs[Properties.cell_offset] = torch.FloatTensor(intra_offsets)
    else:
        inputs[Properties.cell_offset_intra] = torch.FloatTensor(
            offset_intra.astype(np.float32))
        intra_offsets = offset_intra[not_self, :].reshape(n_atoms, n_intra - 1, 3)
        inputs[Properties.cell_offset] = torch.FloatTensor(
            intra_offsets.astype(np.float32))
def _convert_atoms(
    atoms,
    environment_provider=SimpleEnvironmentProvider(),
    collect_triples=False,
    centering_function=None,
    output=None,
    deleteIntraatomicInteraction=False,
    intraAtomicIdent=None,
):
    """
    Helper function to convert ASE atoms object to SchNetPack input format.

    Args:
        atoms (ase.Atoms): Atoms object of molecule
        environment_provider (callable): Neighbor list provider.
        collect_triples (bool, optional): Set to True if angular features are
            needed.
        centering_function (callable or None): Function for calculating center
            of molecule (center of mass/geometry/...). Center will be subtracted
            from positions.
        output (dict): Destination for converted atoms, if not None
        deleteIntraatomicInteraction (bool, optional): If True, every neighbor
            whose ``intraAtomicIdent`` entry differs from that of the central
            atom is replaced by -1 in the stored neighbor list.
        intraAtomicIdent (indexable or None): Per-atom group identifier used
            when ``deleteIntraatomicInteraction`` is True. If None, it is read
            from ``inputs['props'][:, :, 7:8]``, so ``output`` must then
            contain ``'props'``.

    Returns:
        dict of torch.Tensor: Properties including neighbor lists,
            reformatted into SchNetPack input format.
    """
    inputs = {} if output is None else output

    # Elemental composition (np.int was removed in NumPy 1.24)
    cell = np.array(atoms.cell.array, dtype=np.float32)
    inputs[Properties.Z] = torch.LongTensor(atoms.numbers.astype(np.int64))
    positions = atoms.positions.astype(np.float32)
    if centering_function:
        positions -= centering_function(atoms)
    inputs[Properties.R] = torch.FloatTensor(positions)
    # Stored once; it was previously assigned twice with the same value
    inputs[Properties.cell] = torch.FloatTensor(cell)

    # get atom environment
    nbh_idx, offsets = environment_provider.get_environment(atoms)

    if deleteIntraatomicInteraction:
        if intraAtomicIdent is None:
            intraAtomicIdent = inputs['props'][:, :, 7:8]
        neighbors = _remove_cross_group_neighbors(nbh_idx, intraAtomicIdent)
    else:
        neighbors = nbh_idx
    inputs[Properties.neighbors] = torch.LongTensor(neighbors.astype(np.int64))

    inputs[Properties.cell_offset] = torch.FloatTensor(offsets.astype(np.float32))

    # If requested get neighbor lists for triples
    if collect_triples:
        # NOTE(review): triples are built from the unfiltered nbh_idx, so they
        # still contain neighbors removed by deleteIntraatomicInteraction.
        # Confirm this is intended.
        nbh_idx_j, nbh_idx_k, offset_idx_j, offset_idx_k = collect_atom_triples(
            nbh_idx)
        inputs[Properties.neighbor_pairs_j] = torch.LongTensor(
            nbh_idx_j.astype(np.int64))
        inputs[Properties.neighbor_pairs_k] = torch.LongTensor(
            nbh_idx_k.astype(np.int64))
        inputs[Properties.neighbor_offsets_j] = torch.LongTensor(
            offset_idx_j.astype(np.int64))
        inputs[Properties.neighbor_offsets_k] = torch.LongTensor(
            offset_idx_k.astype(np.int64))

    return inputs


def _remove_cross_group_neighbors(nbh_idx, group_ids):
    """Return a copy of ``nbh_idx`` with cross-group neighbors set to -1.

    Row i is the neighbor list of atom i. A neighbor n is dropped (set to -1)
    when ``group_ids[i] != group_ids[n]``. Padding entries (< 0) are left
    untouched instead of being compared against the last atom's group.
    """
    filtered = nbh_idx.astype(np.int64)  # astype returns a new array
    for atom, row in enumerate(filtered):
        for slot, neighbor in enumerate(row):
            if neighbor >= 0 and group_ids[atom] != group_ids[neighbor]:
                row[slot] = -1
    return filtered