class TestGroundStateFinderTwoActiveSublattices(unittest.TestCase):
    """Container for test of the class functionality for a system with two
    active sublattices."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        a = 4.0
        self.chemical_symbols = [['Au', 'Pd'], ['Li', 'Na']]
        self.cutoffs = [3.0]
        # Primitive structure: an Au site (allowed species Au/Pd) plus a Li
        # site (allowed species Li/Na) at (a/2, a/2, a/2).
        structure_prim = bulk(self.chemical_symbols[0][0], a=a)
        structure_prim.append(
            Atom(self.chemical_symbols[1][0], position=(a / 2, a / 2, a / 2)))
        structure_prim.wrap()
        self.structure_prim = structure_prim
        self.cs = ClusterSpace(self.structure_prim, self.cutoffs, self.chemical_symbols)
        parameters = [0.1, -0.45, 0.333, 2, -1.42, 0.98]
        self.ce = ClusterExpansion(self.cs, parameters)

        self.supercell = self.structure_prim.repeat(2)
        symbols = self.supercell.get_chemical_symbols()
        # Site indices of the two sublattices, identified by the species that
        # occupies them in the pristine supercell.
        self.sl1_indices = [
            s for s, sym in enumerate(symbols) if sym == self.chemical_symbols[0][0]
        ]
        self.sl2_indices = [
            s for s, sym in enumerate(symbols) if sym == self.chemical_symbols[1][0]
        ]

        # Reference set: every structure with exactly one substitution on each
        # sublattice (one Pd and one Na).
        self.all_possible_structures = []
        for i in self.sl1_indices:
            for j in self.sl2_indices:
                structure = self.supercell.copy()
                structure.symbols[i] = self.chemical_symbols[0][1]
                structure.symbols[j] = self.chemical_symbols[1][1]
                self.all_possible_structures.append(structure)

    def shortDescription(self):
        """Silences unittest from printing the docstrings in test cases."""
        return None

    def setUp(self):
        """Setup before each test."""
        self.gsf = icet.tools.ground_state_finder.GroundStateFinder(
            self.ce, self.supercell, verbose=False)

    def test_init(self):
        """Tests that initialization of tested class work."""
        # initialize from ClusterExpansion instance
        gsf = icet.tools.ground_state_finder.GroundStateFinder(
            self.ce, self.supercell, verbose=False)
        self.assertIsInstance(gsf, icet.tools.ground_state_finder.GroundStateFinder)

    def test_get_ground_state(self):
        """Tests get_ground_state functionality."""
        # Energies of different (but equivalent) structures are compared, so
        # allow for floating-point round-off instead of demanding exact equality.
        target_val = min(
            self.ce.predict(structure) for structure in self.all_possible_structures)

        # Provide counts for the first/first species
        species_count = {
            self.chemical_symbols[0][0]: len(self.sl1_indices) - 1,
            self.chemical_symbols[1][0]: len(self.sl2_indices) - 1
        }
        ground_state = self.gsf.get_ground_state(species_count=species_count)
        predicted_species00 = self.ce.predict(ground_state)
        self.assertAlmostEqual(predicted_species00, target_val)

        # Provide counts for the first/second species
        species_count = {
            self.chemical_symbols[0][0]: len(self.sl1_indices) - 1,
            self.chemical_symbols[1][1]: 1
        }
        ground_state = self.gsf.get_ground_state(species_count=species_count)
        predicted_species01 = self.ce.predict(ground_state)
        self.assertAlmostEqual(predicted_species01, predicted_species00)

        # Provide counts for the second/second species
        species_count = {
            self.chemical_symbols[0][1]: 1,
            self.chemical_symbols[1][1]: 1
        }
        ground_state = self.gsf.get_ground_state(species_count=species_count)
        predicted_species11 = self.ce.predict(ground_state)
        self.assertAlmostEqual(predicted_species11, predicted_species01)

    def _test_ground_state_cluster_vectors_in_database(self, db_name):
        """Tests get_ground_state functionality by comparing the cluster
        vectors for the structures in the databases.

        Parameters
        ----------
        db_name
            path of the database, relative to the directory of this file
        """
        filename = inspect.getframeinfo(inspect.currentframe()).filename
        path = os.path.dirname(os.path.abspath(filename))
        db = db_connect(os.path.join(path, db_name))

        # Select the structure set with the lowest pairwise correlations
        selections = ['id={}'.format(i) for i in [7, 13, 15, 62, 76]]
        for selection in selections:
            row = db.get(selection)
            structure = row.toatoms()
            target_cluster_vector = self.cs.get_cluster_vector(structure)
            # Request the same composition as the database structure
            symbols = structure.get_chemical_symbols()
            species_count = {
                self.chemical_symbols[0][0]: symbols.count(self.chemical_symbols[0][0]),
                self.chemical_symbols[1][0]: symbols.count(self.chemical_symbols[1][0])
            }
            ground_state = self.gsf.get_ground_state(species_count=species_count)
            gs_cluster_vector = self.cs.get_cluster_vector(ground_state)
            mean_diff = np.mean(np.abs(target_cluster_vector - gs_cluster_vector))
            self.assertLess(mean_diff, 1e-8)

    def test_ground_state_cluster_vectors(self):
        """Tests get_ground_state functionality by comparing the cluster
        vectors for ground states obtained from simulated annealing."""
        self._test_ground_state_cluster_vectors_in_database(
            '../../../structure_databases/annealing_ground_states.db')

    def test_get_ground_state_fails_for_faulty_species_to_count(self):
        """Tests that get_ground_state fails if species_to_count is faulty."""
        # Check that get_ground_state fails if counts are provided for both
        # species on one of the active sublattices
        species_count = {
            self.chemical_symbols[0][0]: len(self.sl1_indices) - 1,
            self.chemical_symbols[0][1]: 1,
            self.chemical_symbols[1][1]: 1
        }
        with self.assertRaises(ValueError) as cm:
            self.gsf.get_ground_state(species_count=species_count)
        self.assertIn(
            'Provide counts for at most one of the species on each active sublattice '
            '({}), not {}!'.format(self.gsf._active_species, list(species_count.keys())),
            str(cm.exception))

        # Check that get_ground_state fails if the count exceeds the number of
        # sites on the first sublattice
        faulty_species = self.chemical_symbols[0][1]
        faulty_count = len(self.supercell)
        species_count = {
            faulty_species: faulty_count,
            self.chemical_symbols[1][1]: 1
        }
        n_active_sites = len(self.sl1_indices)
        with self.assertRaises(ValueError) as cm:
            self.gsf.get_ground_state(species_count=species_count)
        self.assertIn(
            'The count for species {} ({}) must be a positive integer and cannot '
            'exceed the number of sites on the active sublattice '
            '({})'.format(faulty_species, faulty_count, n_active_sites),
            str(cm.exception))

        # Check that get_ground_state fails if the count exceeds the number of
        # sites on the second sublattice
        faulty_species = self.chemical_symbols[1][1]
        faulty_count = len(self.supercell)
        species_count = {
            faulty_species: faulty_count,
            self.chemical_symbols[0][1]: 1
        }
        n_active_sites = len(self.sl2_indices)
        with self.assertRaises(ValueError) as cm:
            self.gsf.get_ground_state(species_count=species_count)
        self.assertIn(
            'The count for species {} ({}) must be a positive integer and cannot '
            'exceed the number of sites on the active sublattice '
            '({})'.format(faulty_species, faulty_count, n_active_sites),
            str(cm.exception))

    def test_get_ground_state_passes_for_partial_species_to_count(self):
        """Tests that get_ground_state accepts partial or missing counts."""
        # Check that get_ground_state passes if a single count is provided
        species_count = {self.chemical_symbols[0][1]: 1}
        self.gsf.get_ground_state(species_count=species_count)

        # Check that get_ground_state passes if no counts are provided; the
        # expected result for this CE is the pristine Au8Li8 supercell
        gs = self.gsf.get_ground_state()
        self.assertEqual(gs.get_chemical_formula(), "Au8Li8")
class TestGroundStateFinderTriplets(unittest.TestCase):
    """Container for test of the class functionality for a system with
    triplets."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.chemical_symbols = ['Au', 'Pd']
        # Two cutoffs, so that triplets are included in the cluster space
        self.cutoffs = [3.0, 3.0]
        structure_prim = fcc111(self.chemical_symbols[0], a=4.0, size=(1, 1, 6),
                                vacuum=10, periodic=True)
        structure_prim.wrap()
        self.structure_prim = structure_prim
        self.cs = ClusterSpace(self.structure_prim, self.cutoffs, self.chemical_symbols)
        parameters = [0.0] * 4 + [0.1] * 6 + [-0.02] * 11
        self.ce = ClusterExpansion(self.cs, parameters)

        # Reference set: every structure with a single Pd substitution
        self.supercell = self.structure_prim.repeat((2, 2, 1))
        self.all_possible_structures = []
        for i in range(len(self.supercell)):
            structure = self.supercell.copy()
            structure.symbols[i] = self.chemical_symbols[1]
            self.all_possible_structures.append(structure)

    def shortDescription(self):
        """Silences unittest from printing the docstrings in test cases."""
        return None

    def setUp(self):
        """Setup before each test."""
        self.gsf = icet.tools.ground_state_finder.GroundStateFinder(
            self.ce, self.supercell, verbose=False)

    def test_init(self):
        """Tests that initialization of tested class work."""
        # initialize from ClusterExpansion instance
        gsf = icet.tools.ground_state_finder.GroundStateFinder(
            self.ce, self.supercell, verbose=False)
        self.assertIsInstance(gsf, icet.tools.ground_state_finder.GroundStateFinder)

    def test_get_ground_state(self):
        """Tests get_ground_state functionality."""
        # Energies of different (but equivalent) structures are compared, so
        # allow for floating-point round-off instead of demanding exact equality.
        target_val = min(
            self.ce.predict(structure) for structure in self.all_possible_structures)

        # Provide counts for first species
        species_count = {self.chemical_symbols[0]: len(self.supercell) - 1}
        ground_state = self.gsf.get_ground_state(species_count=species_count, threads=1)
        predicted_species0 = self.ce.predict(ground_state)
        self.assertAlmostEqual(predicted_species0, target_val)

        # Provide counts for second species
        species_count = {self.chemical_symbols[1]: 1}
        ground_state = self.gsf.get_ground_state(species_count=species_count, threads=1)
        predicted_species1 = self.ce.predict(ground_state)
        self.assertAlmostEqual(predicted_species0, predicted_species1)

        # Check that get_ground_state finds 50-50 mix when no counts are provided
        gsf = icet.tools.ground_state_finder.GroundStateFinder(
            self.ce, self.supercell, verbose=False)
        gs = gsf.get_ground_state(threads=1)
        self.assertEqual(gs.get_chemical_formula(), "Au12Pd12")

        # Ensure that an exception is raised when no solution is found
        # (a zero time limit leaves the solver no time to find one)
        gsf = icet.tools.ground_state_finder.GroundStateFinder(
            self.ce, self.supercell, solver_name='CBC', verbose=False)
        species_count = {self.chemical_symbols[1]: 1}
        with self.assertRaises(Exception) as cm:
            gsf.get_ground_state(species_count=species_count, max_seconds=0.0, threads=1)
        self.assertIn('Optimization failed', str(cm.exception))
class TestGroundStateFinderZeroParameter(unittest.TestCase):
    """Container for test of the class functionality for a system with a zero
    parameter."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.chemical_symbols = ['Ag', 'Au']
        self.cutoffs = [4.3]
        self.structure_prim = bulk(self.chemical_symbols[1], a=4.0)
        self.cs = ClusterSpace(self.structure_prim, self.cutoffs, self.chemical_symbols)

        # Construct a CE whose transformed parameter with index 1 is zero:
        # transform a reference parameter set, zero that entry, and map the
        # result back through the transformation matrix A.
        # NOTE(review): presumably transform_parameters returns the parameters
        # in the binary basis used by the ground state finder (the name
        # binary_parameters_zero suggests so) -- confirm.
        nonzero_ce = ClusterExpansion(self.cs, [0, 0, 0.1, -0.02])
        lolg = LocalOrbitListGenerator(
            self.cs.orbit_list, Structure.from_atoms(self.structure_prim),
            self.cs.fractional_position_tolerance)
        full_orbit_list = lolg.generate_full_orbit_list()
        binary_parameters_zero = transform_parameters(
            self.structure_prim, full_orbit_list, nonzero_ce.parameters)
        binary_parameters_zero[1] = 0
        A = get_transformation_matrix(self.structure_prim, full_orbit_list)
        # Same as inv(A) @ b, without explicitly forming the inverse
        zero_parameters = np.linalg.solve(A, binary_parameters_zero)
        self.ce = ClusterExpansion(self.cs, zero_parameters)

        # Reference set: every structure with a single Ag substitution
        self.supercell = self.structure_prim.repeat(2)
        self.all_possible_structures = []
        for i in range(len(self.supercell)):
            structure = self.supercell.copy()
            structure.symbols[i] = self.chemical_symbols[0]
            self.all_possible_structures.append(structure)

    def shortDescription(self):
        """Silences unittest from printing the docstrings in test cases."""
        return None

    def setUp(self):
        """Setup before each test."""
        self.gsf = icet.tools.ground_state_finder.GroundStateFinder(
            self.ce, self.supercell, verbose=False)

    def test_init(self):
        """Tests that initialization of tested class work."""
        # initialize from ClusterExpansion instance
        gsf = icet.tools.ground_state_finder.GroundStateFinder(
            self.ce, self.supercell, verbose=False)
        self.assertIsInstance(gsf, icet.tools.ground_state_finder.GroundStateFinder)

    def test_get_ground_state(self):
        """Tests get_ground_state functionality."""
        # Energies of different (but equivalent) structures are compared, so
        # allow for floating-point round-off instead of demanding exact equality.
        target_val = min(
            self.ce.predict(structure) for structure in self.all_possible_structures)

        # Provide counts for first species
        species_count = {self.chemical_symbols[0]: 1}
        ground_state = self.gsf.get_ground_state(species_count=species_count, threads=1)
        predicted_species0 = self.ce.predict(ground_state)
        self.assertAlmostEqual(predicted_species0, target_val)

        # Provide counts for second species
        species_count = {self.chemical_symbols[1]: len(self.supercell) - 1}
        ground_state = self.gsf.get_ground_state(species_count=species_count, threads=1)
        predicted_species1 = self.ce.predict(ground_state)
        self.assertAlmostEqual(predicted_species0, predicted_species1)
class TestGroundStateFinderInactiveSublatticeSameSpecies(unittest.TestCase):
    """Container for test of the class functionality for a system with an
    inactive sublattice occupied by a species found on the active
    sublattice."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.chemical_symbols = [['Ag', 'Au'], ['Ag']]
        self.cutoffs = [4.3]
        a = 4.0
        # Au site (allowed species Ag/Au) plus a site restricted to Ag at
        # (a/2, a/2, a/2)
        structure_prim = bulk(self.chemical_symbols[0][1], a=a)
        structure_prim.append(
            Atom(self.chemical_symbols[1][0], position=(a / 2, a / 2, a / 2)))
        self.structure_prim = structure_prim
        self.cs = ClusterSpace(self.structure_prim, self.cutoffs, self.chemical_symbols)
        self.ce = ClusterExpansion(self.cs, [0, 0, 0.1, -0.02])

        self.supercell = self.structure_prim.repeat(2)
        sublattices = self.cs.get_sublattices(self.supercell)
        self.n_active_sites = [
            len(subl.indices) for subl in sublattices.active_sublattices
        ]

        # Reference set: every structure with a single Ag on an active site.
        # Only active sites are substituted. The inactive sites already hold
        # Ag, so "substituting" them would merely add copies of the pristine
        # supercell to the reference set and distort target_val.
        self.all_possible_structures = []
        for subl in sublattices.active_sublattices:
            for i in subl.indices:
                structure = self.supercell.copy()
                structure.symbols[i] = self.chemical_symbols[0][0]
                self.all_possible_structures.append(structure)

    def shortDescription(self):
        """Silences unittest from printing the docstrings in test cases."""
        return None

    def setUp(self):
        """Setup before each test."""
        self.gsf = icet.tools.ground_state_finder.GroundStateFinder(
            self.ce, self.supercell, verbose=False)

    def test_init(self):
        """Tests that initialization of tested class work."""
        # initialize from ClusterExpansion instance
        gsf = icet.tools.ground_state_finder.GroundStateFinder(
            self.ce, self.supercell, verbose=False)
        self.assertIsInstance(gsf, icet.tools.ground_state_finder.GroundStateFinder)

    def test_get_ground_state(self):
        """Tests get_ground_state functionality."""
        # Energies of different (but equivalent) structures are compared, so
        # allow for floating-point round-off instead of demanding exact equality.
        target_val = min(
            self.ce.predict(structure) for structure in self.all_possible_structures)

        # Provide counts for first species
        species_count = {self.chemical_symbols[0][0]: 1}
        ground_state = self.gsf.get_ground_state(species_count=species_count, threads=1)
        predicted_species0 = self.ce.predict(ground_state)
        self.assertAlmostEqual(predicted_species0, target_val)

        # Provide counts for second species (same composition as above)
        species_count = {
            self.chemical_symbols[0][1]: self.n_active_sites[0] - 1
        }
        ground_state = self.gsf.get_ground_state(species_count=species_count, threads=1)
        predicted_species1 = self.ce.predict(ground_state)
        self.assertAlmostEqual(predicted_species0, predicted_species1)

    def test_get_ground_state_fails_for_faulty_species_to_count(self):
        """Tests that get_ground_state fails if species_to_count is faulty."""
        # Check that get_ground_state fails if counts are provided for multiple species
        species_count = {
            self.chemical_symbols[0][0]: 1,
            self.chemical_symbols[0][1]: self.n_active_sites[0] - 1
        }
        with self.assertRaises(ValueError) as cm:
            self.gsf.get_ground_state(species_count=species_count)
        self.assertIn(
            'Provide counts for at most one of the species on each active sublattice '
            '({}), not {}!'.format(self.gsf._active_species, list(species_count.keys())),
            str(cm.exception))

        # Check that get_ground_state fails if counts are provided for a
        # species not found on the active sublattice
        species_count = {'H': 1}
        with self.assertRaises(ValueError) as cm:
            self.gsf.get_ground_state(species_count=species_count)
        self.assertIn(
            'The species {} is not present on any of the active sublattices'
            ' ({})'.format(list(species_count.keys())[0], self.gsf._active_species),
            str(cm.exception))

        # Check that get_ground_state fails if the count exceeds the number of
        # sites on the active sublattice
        faulty_species = self.chemical_symbols[0][0]
        faulty_count = len(self.supercell)
        species_count = {faulty_species: faulty_count}
        n_active_sites = self.n_active_sites[0]
        with self.assertRaises(ValueError) as cm:
            self.gsf.get_ground_state(species_count=species_count)
        self.assertIn(
            'The count for species {} ({}) must be a positive integer and cannot '
            'exceed the number of sites on the active sublattice '
            '({})'.format(faulty_species, faulty_count, n_active_sites),
            str(cm.exception))

    def test_create_cluster_maps(self):
        """Tests _create_cluster_maps functionality """
        gsf = icet.tools.ground_state_finder.GroundStateFinder(
            self.ce, self.supercell, verbose=False)
        gsf._create_cluster_maps(self.structure_prim)

        # Test cluster to sites map
        target = [[0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0],
                  [0, 0], [0, 0]]
        self.assertEqual(target, gsf._cluster_to_sites_map)

        # Test cluster to orbit map
        target = [0, 1, 1, 1, 1, 1, 1, 2, 2, 2]
        self.assertEqual(target, gsf._cluster_to_orbit_map)

        # Test ncluster per orbit map
        target = [1, 1, 6, 3]
        self.assertEqual(target, gsf._nclusters_per_orbit)
class TestGroundStateFinder(unittest.TestCase):
    """Container for test of the class functionality."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.chemical_symbols = ['Ag', 'Au']
        self.cutoffs = [4.3]
        self.structure_prim = bulk(self.chemical_symbols[1], a=4.0)
        self.cs = ClusterSpace(self.structure_prim, self.cutoffs, self.chemical_symbols)
        self.ce = ClusterExpansion(self.cs, [0, 0, 0.1, -0.02])

        # Reference set: every structure with a single Ag substitution
        self.supercell = self.structure_prim.repeat(2)
        self.all_possible_structures = []
        for i in range(len(self.supercell)):
            structure = self.supercell.copy()
            structure.symbols[i] = self.chemical_symbols[0]
            self.all_possible_structures.append(structure)

    def shortDescription(self):
        """Silences unittest from printing the docstrings in test cases."""
        return None

    def setUp(self):
        """Setup before each test."""
        self.gsf = icet.tools.ground_state_finder.GroundStateFinder(
            self.ce, self.supercell, verbose=False)

    def test_mip_import(self):
        """Tests the Python-MIP import statement"""
        # The reloads below deliberately fail part-way through module
        # execution; reload once more in the normal environment afterwards so
        # that later tests see a fully initialized module.
        self.addCleanup(importlib.reload, icet.tools.ground_state_finder)

        # Test that an error is raised if Python-MIP is not installed
        with self.assertRaises(ImportError) as cm:
            with mock.patch.dict(sys.modules, {'mip': None}):
                importlib.reload(icet.tools.ground_state_finder)
        self.assertIn(
            'Python-MIP (https://python-mip.readthedocs.io/en/latest/) is required in '
            'order to use the ground state finder.', str(cm.exception))

        # Test that an error is raised if the Python-MIP version is not sufficiently recent
        with self.assertRaises(VersionConflict) as cm:
            with mock.patch('mip.constants.VERSION', '1.6.2'):
                importlib.reload(icet.tools.ground_state_finder)
        self.assertIn(
            'Python-MIP version 1.6.3 or later is required in order to use the ground '
            'state finder.', str(cm.exception))

    def test_init(self):
        """Tests that initialization of tested class work."""
        # initialize from GroundStateFinder instance
        gsf = icet.tools.ground_state_finder.GroundStateFinder(
            self.ce, self.supercell, verbose=False)
        self.assertIsInstance(gsf, icet.tools.ground_state_finder.GroundStateFinder)

    def test_init_solver(self):
        """Tests that the solver can be set explicitly at initialization."""
        gsf = icet.tools.ground_state_finder.GroundStateFinder(
            self.ce, self.supercell, solver_name='CBC', verbose=False)
        self.assertEqual('CBC', gsf._model.solver_name.upper())

    def test_init_fails_for_ternary_with_one_active_sublattice(self):
        """Tests that initialization fails for a ternary system with one active
        sublattice."""
        chemical_symbols = ['Au', 'Ag', 'Pd']
        cs = ClusterSpace(self.structure_prim, cutoffs=self.cutoffs,
                          chemical_symbols=chemical_symbols)
        ce = ClusterExpansion(cs, [0.0] * len(cs))
        with self.assertRaises(NotImplementedError) as cm:
            icet.tools.ground_state_finder.GroundStateFinder(ce, self.supercell, verbose=False)
        self.assertIn(
            'Currently, systems with more than two allowed species on any sublattice '
            'are not supported.', str(cm.exception))

    def test_optimization_status_property(self):
        """Tests the optimization_status property."""
        # Check that the optimization_status is None initially
        self.assertIsNone(self.gsf.optimization_status)

        # Check that the optimization_status is OPTIMAL if a ground state is found
        species_count = {self.chemical_symbols[0]: 1}
        self.gsf.get_ground_state(species_count=species_count, threads=1)
        self.assertEqual(str(self.gsf.optimization_status), 'OptimizationStatus.OPTIMAL')

    def test_model_property(self):
        """Tests the model property."""
        self.assertEqual(self.gsf.model.name, 'CE')

    def test_get_ground_state(self):
        """Tests get_ground_state functionality."""
        # Energies of different (but equivalent) structures are compared, so
        # allow for floating-point round-off instead of demanding exact equality.
        target_val = min(
            self.ce.predict(structure) for structure in self.all_possible_structures)

        # Provide counts for first species
        species_count = {self.chemical_symbols[0]: 1}
        ground_state = self.gsf.get_ground_state(species_count=species_count, threads=1)
        predicted_species0 = self.ce.predict(ground_state)
        self.assertAlmostEqual(predicted_species0, target_val)

        # Provide counts for second species
        species_count = {self.chemical_symbols[1]: len(self.supercell) - 1}
        ground_state = self.gsf.get_ground_state(species_count=species_count, threads=1)
        predicted_species1 = self.ce.predict(ground_state)
        self.assertAlmostEqual(predicted_species0, predicted_species1)

        # Set the maximum run time; a time-limited solution can at best match
        # the optimum found above
        species_count = {self.chemical_symbols[0]: 1}
        ground_state = self.gsf.get_ground_state(species_count=species_count,
                                                 max_seconds=0.5, threads=1)
        predicted_max_seconds = self.ce.predict(ground_state)
        self.assertGreaterEqual(predicted_max_seconds, predicted_species0)

    def test_get_ground_state_fails_for_faulty_species_to_count(self):
        """Tests that get_ground_state fails if species_to_count is faulty."""
        # Check that get_ground_state fails if counts are provided for multiple
        # species
        species_count = {
            self.chemical_symbols[0]: 1,
            self.chemical_symbols[1]: len(self.supercell) - 1
        }
        with self.assertRaises(ValueError) as cm:
            self.gsf.get_ground_state(species_count=species_count)
        self.assertIn(
            'Provide counts for at most one of the species on each active sublattice '
            '({}), not {}!'.format(self.gsf._active_species, list(species_count.keys())),
            str(cm.exception))

        # Check that get_ground_state fails if counts are provided for a
        # species not found on the active sublattice
        species_count = {'H': 1}
        with self.assertRaises(ValueError) as cm:
            self.gsf.get_ground_state(species_count=species_count)
        self.assertIn(
            'The species {} is not present on any of the active sublattices'
            ' ({})'.format(list(species_count.keys())[0], self.gsf._active_species),
            str(cm.exception))

        # Check that get_ground_state fails if the count exceeds the number of
        # sites on the active sublattice
        faulty_species = self.chemical_symbols[0]
        faulty_count = len(self.supercell) + 1
        species_count = {faulty_species: faulty_count}
        with self.assertRaises(ValueError) as cm:
            self.gsf.get_ground_state(species_count=species_count)
        self.assertIn(
            'The count for species {} ({}) must be a positive integer and cannot '
            'exceed the number of sites on the active sublattice '
            '({})'.format(faulty_species, faulty_count, len(self.supercell)),
            str(cm.exception))

        # Check that get_ground_state fails if the count is not a positive integer
        species = self.chemical_symbols[0]
        count = -1
        species_count = {species: count}
        with self.assertRaises(ValueError) as cm:
            self.gsf.get_ground_state(species_count=species_count)
        self.assertIn(
            'The count for species {} ({}) must be a positive integer and cannot '
            'exceed the number of sites on the active sublattice '
            '({})'.format(species, count, len(self.supercell)),
            str(cm.exception))

    def test_create_cluster_maps(self):
        """Tests _create_cluster_maps functionality """
        gsf = icet.tools.ground_state_finder.GroundStateFinder(
            self.ce, self.supercell, verbose=False)
        gsf._create_cluster_maps(self.structure_prim)

        # Test cluster to sites map
        target = [[0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0],
                  [0, 0], [0, 0]]
        self.assertEqual(target, gsf._cluster_to_sites_map)

        # Test cluster to orbit map
        target = [0, 1, 1, 1, 1, 1, 1, 2, 2, 2]
        self.assertEqual(target, gsf._cluster_to_orbit_map)

        # Test ncluster per orbit map
        target = [1, 1, 6, 3]
        self.assertEqual(target, gsf._nclusters_per_orbit)
class TestClusterExpansion(unittest.TestCase):
    """Container for tests of the class functionality."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.structure = bulk('Au')
        self.cutoffs = [3.0] * 3
        chemical_symbols = ['Au', 'Pd']
        self.cs = ClusterSpace(self.structure, self.cutoffs, chemical_symbols)

    def shortDescription(self):
        """Silences unittest from printing the docstrings in test cases."""
        return None

    def setUp(self):
        """Setup before each test."""
        # A fresh CE per test, since several tests prune (mutate) it
        params_len = len(self.cs)
        self.parameters = np.arange(params_len)
        self.ce = ClusterExpansion(self.cs, self.parameters)

    def test_init(self):
        """Tests that initialization works."""
        self.assertIsInstance(self.ce, ClusterExpansion)

        # test whether method raises Exception
        with self.assertRaises(ValueError) as context:
            ClusterExpansion(self.cs, [0.0])
        self.assertIn('cluster_space (5) and parameters (1) must have the'
                      ' same length', str(context.exception))

    def test_predict(self):
        """Tests predict function."""
        predicted_val = self.ce.predict(self.structure)
        self.assertEqual(predicted_val, 10.0)

    def test_property_orders(self):
        """Tests orders property."""
        self.assertEqual(self.ce.orders, list(range(len(self.cutoffs) + 2)))

    def test_property_to_dataframe(self):
        """Tests to_dataframe() property."""
        df = self.ce.to_dataframe()
        self.assertIn('radius', df.columns)
        self.assertIn('order', df.columns)
        self.assertIn('eci', df.columns)
        self.assertEqual(len(df), len(self.parameters))

    def test_get__clusterspace_copy(self):
        """Tests get cluster space copy."""
        self.assertEqual(str(self.ce.get_cluster_space_copy()), str(self.cs))

    def test_property_parameters(self):
        """Tests parameters properties."""
        self.assertEqual(list(self.ce.parameters), list(self.parameters))

    def test_len(self):
        """Tests len functionality."""
        self.assertEqual(len(self.ce), len(self.parameters))

    def test_read_write(self):
        """Tests read and write functionalities."""
        # The context manager closes (and removes) the temporary file even if
        # writing or reading fails.
        with tempfile.NamedTemporaryFile() as temp_file:
            # save to file
            self.ce.write(temp_file.name)
            # read from file
            ce_read = ClusterExpansion.read(temp_file.name)

        # check cluster space
        self.assertEqual(self.cs._input_structure, ce_read._cluster_space._input_structure)
        self.assertEqual(self.cs._cutoffs, ce_read._cluster_space._cutoffs)
        self.assertEqual(self.cs._input_chemical_symbols,
                         ce_read._cluster_space._input_chemical_symbols)

        # check parameters
        self.assertIsInstance(ce_read.parameters, np.ndarray)
        self.assertEqual(list(ce_read.parameters), list(self.parameters))

        # check metadata
        self.assertEqual(len(self.ce.metadata), len(ce_read.metadata))
        self.assertSequenceEqual(sorted(self.ce.metadata.keys()),
                                 sorted(ce_read.metadata.keys()))
        for key in self.ce.metadata.keys():
            self.assertEqual(self.ce.metadata[key], ce_read.metadata[key])

    def test_read_write_pruned(self):
        """Tests read and write functionalities."""
        self.ce.prune(indices=[2, 3])
        self.ce.prune(tol=3)
        pruned_params = self.ce.parameters
        pruned_cs_len = len(self.ce._cluster_space)

        with tempfile.NamedTemporaryFile() as temp_file:
            # save to file
            self.ce.write(temp_file.name)
            # read from file
            ce_read = ClusterExpansion.read(temp_file.name)
        params_read = ce_read.parameters
        cs_len_read = len(ce_read._cluster_space)

        # check cluster space
        self.assertEqual(cs_len_read, pruned_cs_len)
        self.assertEqual(list(params_read), list(pruned_params))

    def test_prune_cluster_expansion(self):
        """Tests pruning cluster expansion."""
        len_before = len(self.ce)
        self.ce.prune()
        len_after = len(self.ce)
        self.assertEqual(len_before, len_after)

        # Set all parameters to zero except three
        self.ce._parameters = np.array([0.0] * len_after)
        self.ce._parameters[0] = 1.0
        self.ce._parameters[1] = 2.0
        self.ce._parameters[2] = 0.5
        self.ce.prune()
        self.assertEqual(len(self.ce), 3)
        self.assertNotEqual(len(self.ce), len_after)

    def test_prune_cluster_expansion_tol(self):
        """Tests pruning cluster expansion with tolerance."""
        len_before = len(self.ce)
        self.ce.prune()
        len_after = len(self.ce)
        self.assertEqual(len_before, len_after)

        # Set all parameters to zero except two, one of which is
        # non-zero but below the tolerance
        self.ce._parameters = np.array([0.0] * len_after)
        self.ce._parameters[0] = 1.0
        self.ce._parameters[1] = 0.01
        self.ce.prune(tol=0.02)
        self.assertEqual(len(self.ce), 1)
        self.assertNotEqual(len(self.ce), len_after)

    def test_prune_pairs(self):
        """Tests pruning pairs only."""
        df = self.ce.to_dataframe()
        pair_indices = df.index[df['order'] == 2].tolist()
        self.ce.prune(indices=pair_indices)

        df_new = self.ce.to_dataframe()
        pair_indices_new = df_new.index[df_new['order'] == 2].tolist()
        self.assertEqual(pair_indices_new, [])

    def test_prune_zerolet(self):
        """Tests pruning zerolet."""
        with self.assertRaises(ValueError) as context:
            self.ce.prune(indices=[0])
        self.assertIn('zerolet may not be pruned', str(context.exception))

    def test_plot_ecis(self):
        """Tests plot_ecis."""
        self.ce.plot_ecis()

    def test_repr(self):
        """Tests repr functionality."""
        retval = repr(self.ce)
        target = """ ================================================ Cluster Expansion ================================================= space group : Fm-3m (225) chemical species : ['Au', 'Pd'] (sublattice A) cutoffs : 3.0000 3.0000 3.0000 total number of parameters : 5 number of parameters by order : 0= 1 1= 1 2= 1 3= 1 4= 1 fractional_position_tolerance : 2e-06 position_tolerance : 1e-05 symprec : 1e-05 total number of nonzero parameters : 4 number of nonzero parameters by order : 0= 0 1= 1 2= 1 3= 1 4= 1 -------------------------------------------------------------------------------------------------------------------- index | order | radius | multiplicity | orbit_index | multi_component_vector | sublattices | parameter | ECI -------------------------------------------------------------------------------------------------------------------- 0 | 0 | 0.0000 | 1 | -1 | . | .
| 0 | 0 1 | 1 | 0.0000 | 1 | 0 | [0] | A | 1 | 1 2 | 2 | 1.4425 | 6 | 1 | [0, 0] | A-A | 2 | 0.333 3 | 3 | 1.6657 | 8 | 2 | [0, 0, 0] | A-A-A | 3 | 0.375 4 | 4 | 1.7667 | 2 | 3 | [0, 0, 0, 0] | A-A-A-A | 4 | 2 ==================================================================================================================== """  # noqa
        self.assertEqual(strip_surrounding_spaces(target), strip_surrounding_spaces(retval))

    def test_get_string_representation(self):
        """Tests _get_string_representation functionality."""
        retval = self.ce._get_string_representation(print_threshold=2, print_minimum=1)
        target = """ ================================================ Cluster Expansion ================================================= space group : Fm-3m (225) chemical species : ['Au', 'Pd'] (sublattice A) cutoffs : 3.0000 3.0000 3.0000 total number of parameters : 5 number of parameters by order : 0= 1 1= 1 2= 1 3= 1 4= 1 fractional_position_tolerance : 2e-06 position_tolerance : 1e-05 symprec : 1e-05 total number of nonzero parameters : 4 number of nonzero parameters by order : 0= 0 1= 1 2= 1 3= 1 4= 1 -------------------------------------------------------------------------------------------------------------------- index | order | radius | multiplicity | orbit_index | multi_component_vector | sublattices | parameter | ECI -------------------------------------------------------------------------------------------------------------------- 0 | 0 | 0.0000 | 1 | -1 | . | . | 0 | 0 ...
4 | 4 | 1.7667 | 2 | 3 | [0, 0, 0, 0] | A-A-A-A | 4 | 2 ==================================================================================================================== """  # noqa
        self.assertEqual(strip_surrounding_spaces(target), strip_surrounding_spaces(retval))

    def test_print_overview(self):
        """Tests print_overview functionality."""
        # redirect_stdout restores the *previous* stdout (e.g. a test runner's
        # capture stream) even if print_overview raises; the former manual
        # reset to sys.__stdout__ did neither.
        from contextlib import redirect_stdout
        with StringIO() as captured_output:
            with redirect_stdout(captured_output):
                self.ce.print_overview()
            self.assertIn('Cluster Expansion', captured_output.getvalue())