def test_generator(self):
    """Check autocorrelation features of formic acid against `truth`."""
    mol = molecule('HCOOH')
    mol.center(vacuum=5)
    # Bond cutoffs come from ASE's covalent radii, indexed by atomic number.
    mol.connectivity = ase_connectivity(
        mol, [covalent_radii[number] for number in mol.numbers])
    generator = FeatureGenerator()
    result = generator.return_vec([mol], [generator.get_autocorrelation])
    np.testing.assert_allclose(result, truth)
def images_connectivity(images):
    """Attach a connectivity matrix to each atoms object that lacks one.

    Parameters
    ----------
    images : list
        Atoms objects (presumably ``ase.Atoms``). Any image without a
        ``connectivity`` attribute gets one computed by ``ase_connectivity``,
        using ``default_catlearn_radius`` of each atomic number as the
        radii.

    Returns
    -------
    images : list
        The same list that was passed in; its elements are modified in place.
    """
    # NOTE(review): another ``images_connectivity`` definition later in this
    # module appears to shadow this one -- confirm which is intended.
    for atoms in tqdm(images):
        # Keep any connectivity the caller has already attached.
        if not hasattr(atoms, 'connectivity'):
            radii = [default_catlearn_radius(z) for z in atoms.numbers]
            atoms.connectivity = ase_connectivity(atoms, radii)
    return images
def images_connectivity(images, check_cn_max=False, cn_max=12):
    """Attach a connectivity matrix to each atoms object that lacks one.

    Parameters
    ----------
    images : list
        Atoms objects (presumably ``ase.Atoms``). Any image without a
        ``connectivity`` attribute gets one computed by ``ase_connectivity``,
        using ``default_catlearn_radius`` of each atomic number as cutoffs.
    check_cn_max : bool
        If True, issue a warning for every image in which some atom has more
        than ``cn_max`` connections.
    cn_max : int
        Connection-count threshold used when ``check_cn_max`` is True.
        Defaults to 12.

    Returns
    -------
    images : list
        The same list that was passed in; its elements are modified in place.
    """
    for atoms in tqdm(images):
        # Keep any connectivity the caller has already attached.
        if not hasattr(atoms, 'connectivity'):
            radii = [default_catlearn_radius(z) for z in atoms.numbers]
            atoms.connectivity = ase_connectivity(atoms, cutoffs=radii)
        if not check_cn_max:
            continue
        # Column sums give the number of connections per atom (assumes an
        # (n_atoms, n_atoms) connectivity matrix -- TODO confirm against
        # ase_connectivity). The check also covers caller-supplied matrices.
        n_connections = np.sum(atoms.connectivity, axis=0)
        # Skip empty images: the builtin max() raised on an empty sequence.
        if np.size(n_connections) and np.max(n_connections) > cn_max:
            msg = 'atom has more than {} connections.'.format(cn_max)
            if 'id' in atoms.info:
                # Separate the id from the sentence instead of gluing it on.
                msg += ' id: {}'.format(atoms.info['id'])
            warnings.warn(msg)
    return images