def test_mol_basis_data(self, basis_data):
    r"""Verify that ``mol_basis_data`` builds the expected basis set for a molecule.

    The molecule is described by a basis-set name and a list of atomic symbols. The
    basis-function counts returned by ``mol_basis_data`` must equal the reference
    exactly, and the returned parameters must match the reference parameters to
    within ``np.allclose`` tolerance.
    """
    name, atoms, expected_count, expected_params = basis_data

    count, params = mol_basis_data(name, atoms)

    assert count == expected_count
    assert np.allclose(params, expected_params)
def __init__(
    self,
    symbols,
    coordinates,
    charge=0,
    mult=1,
    basis_name="sto-3g",
    l=None,
    alpha=None,
    coeff=None,
):
    r"""Create a molecule object and build its basis set.

    Args:
        symbols (list[str]): atomic symbols of the atoms in the molecule
        coordinates: nuclear positions; ``coordinates[i]`` is used as the centre of every
            basis function belonging to atom ``i`` (presumably an array of shape
            ``(n_atoms, 3)`` — TODO confirm against callers)
        charge (int): net charge of the molecule, subtracted from the total nuclear charge
            to obtain the number of electrons
        mult (int): spin multiplicity (stored on the instance; not used in this constructor)
        basis_name (str): name of the basis set; only ``'sto-3g'`` / ``'STO-3G'`` is accepted
        l (list or None): per-basis-function ``l`` values passed to ``BasisFunction``;
            defaults to the first entry of each item returned by ``mol_basis_data``
        alpha (list or None): per-basis-function exponents passed to ``BasisFunction``;
            defaults to the second entry of each item returned by ``mol_basis_data``,
            wrapped in ``pnp.array(..., requires_grad=False)``
        coeff (list or None): per-basis-function coefficients passed to ``BasisFunction``;
            defaults to the third entry of each item returned by ``mol_basis_data``,
            wrapped in ``pnp.array(..., requires_grad=False)``

    Raises:
        ValueError: if ``basis_name`` is not supported, if any symbol in ``symbols`` has no
            entry in ``atomic_numbers``, or if ``l``, ``alpha`` or ``coeff`` does not contain
            exactly one entry per basis function.
    """
    if basis_name not in ["sto-3g", "STO-3G"]:
        raise ValueError("Currently, the only supported basis set is 'sto-3g'.")

    # Compute the set of unsupported atoms once and reuse it in the error message.
    unsupported = set(symbols) - set(atomic_numbers)
    if unsupported:
        raise ValueError(f"Atoms in {unsupported} are not supported.")

    self.symbols = symbols
    self.coordinates = coordinates
    self.charge = charge
    self.mult = mult
    self.basis_name = basis_name

    # n_basis[i] is the number of basis functions on atom i; basis_data holds one
    # parameter entry per basis function.
    self.n_basis, self.basis_data = mol_basis_data(self.basis_name, self.symbols)

    # Fall back to the default basis parameters for anything not supplied by the caller.
    if l is None:
        l = [i[0] for i in self.basis_data]
    if alpha is None:
        alpha = [pnp.array(i[1], requires_grad=False) for i in self.basis_data]
    if coeff is None:
        coeff = [pnp.array(i[2], requires_grad=False) for i in self.basis_data]

    # Repeat each atomic position once per basis function on that atom, giving one
    # centre per basis function.
    r = list(
        itertools.chain.from_iterable(
            [self.coordinates[i]] * n for i, n in enumerate(self.n_basis)
        )
    )

    # Custom parameters must line up one-to-one with the basis-function centres;
    # otherwise functions would be silently dropped or paired with the wrong centre.
    n_functions = len(r)
    for arg_name, values in (("l", l), ("alpha", alpha), ("coeff", coeff)):
        if len(values) != n_functions:
            raise ValueError(
                f"Expected {n_functions} entries in '{arg_name}' (one per basis function), "
                f"got {len(values)}."
            )

    self.l = l
    self.alpha = alpha
    self.coeff = coeff
    self.r = r

    # Parameters are accessed by index (not zip) so array-like inputs are read exactly
    # as before.
    self.basis_set = [
        BasisFunction(self.l[i], self.alpha[i], self.coeff[i], self.r[i])
        for i in range(n_functions)
    ]
    self.n_orbitals = len(self.l)

    self.nuclear_charges = [atomic_numbers[s] for s in self.symbols]
    self.n_electrons = sum(np.array(self.nuclear_charges)) - self.charge