Example #1
0
class TestGroundStateFinderTwoActiveSublattices(unittest.TestCase):
    """Container for tests of the class functionality for a system with
    two active sublattices."""
    def __init__(self, *args, **kwargs):
        super(TestGroundStateFinderTwoActiveSublattices,
              self).__init__(*args, **kwargs)
        a = 4.0
        # First sublattice allows Au/Pd, second sublattice allows Li/Na
        self.chemical_symbols = [['Au', 'Pd'], ['Li', 'Na']]
        self.cutoffs = [3.0]
        structure_prim = bulk(self.chemical_symbols[0][0], a=a)
        structure_prim.append(
            Atom(self.chemical_symbols[1][0], position=(a / 2, a / 2, a / 2)))
        structure_prim.wrap()
        self.structure_prim = structure_prim
        self.cs = ClusterSpace(self.structure_prim, self.cutoffs,
                               self.chemical_symbols)
        parameters = [0.1, -0.45, 0.333, 2, -1.42, 0.98]
        self.ce = ClusterExpansion(self.cs, parameters)
        self.all_possible_structures = []
        self.supercell = self.structure_prim.repeat(2)
        # Site indices of each sublattice, identified by the host species in
        # the pristine supercell
        self.sl1_indices = [
            s for s, sym in enumerate(self.supercell.get_chemical_symbols())
            if sym == self.chemical_symbols[0][0]
        ]
        self.sl2_indices = [
            s for s, sym in enumerate(self.supercell.get_chemical_symbols())
            if sym == self.chemical_symbols[1][0]
        ]
        # Brute-force reference set: every structure with exactly one
        # substitution on each of the two sublattices
        for i in self.sl1_indices:
            for j in self.sl2_indices:
                structure = self.supercell.copy()
                structure.symbols[i] = self.chemical_symbols[0][1]
                structure.symbols[j] = self.chemical_symbols[1][1]
                self.all_possible_structures.append(structure)

    def shortDescription(self):
        """Silences unittest from printing the docstrings in test cases."""
        return None

    def setUp(self):
        """Setup before each test."""
        self.gsf = icet.tools.ground_state_finder.GroundStateFinder(
            self.ce, self.supercell, verbose=False)

    def test_init(self):
        """Tests that initialization of the tested class works."""
        # initialize from ClusterExpansion instance
        gsf = icet.tools.ground_state_finder.GroundStateFinder(self.ce,
                                                               self.supercell,
                                                               verbose=False)
        self.assertIsInstance(gsf,
                              icet.tools.ground_state_finder.GroundStateFinder)

    def test_get_ground_state(self):
        """Tests get_ground_state functionality."""
        target_val = min([
            self.ce.predict(structure)
            for structure in self.all_possible_structures
        ])

        # Provide counts for the first/first species (i.e., one Pd and one Na)
        species_count = {
            self.chemical_symbols[0][0]: len(self.sl1_indices) - 1,
            self.chemical_symbols[1][0]: len(self.sl2_indices) - 1
        }
        ground_state = self.gsf.get_ground_state(species_count=species_count)
        predicted_species00 = self.ce.predict(ground_state)
        self.assertEqual(predicted_species00, target_val)

        # Provide counts for the first/second species
        species_count = {
            self.chemical_symbols[0][0]: len(self.sl1_indices) - 1,
            self.chemical_symbols[1][1]: 1
        }
        ground_state = self.gsf.get_ground_state(species_count=species_count)
        predicted_species01 = self.ce.predict(ground_state)
        self.assertEqual(predicted_species01, predicted_species00)

        # Provide counts for the second/second species
        species_count = {
            self.chemical_symbols[0][1]: 1,
            self.chemical_symbols[1][1]: 1
        }
        ground_state = self.gsf.get_ground_state(species_count=species_count)
        predicted_species11 = self.ce.predict(ground_state)
        self.assertEqual(predicted_species11, predicted_species01)

    def _test_ground_state_cluster_vectors_in_database(self, db_name):
        """Tests get_ground_state functionality by comparing the cluster
        vectors for the structures in the database.

        Parameters
        ----------
        db_name
            path to the database, relative to the directory of this file
        """

        filename = inspect.getframeinfo(inspect.currentframe()).filename
        path = os.path.dirname(os.path.abspath(filename))
        db = db_connect(os.path.join(path, db_name))

        # Select the structure set with the lowest pairwise correlations
        selections = ['id={}'.format(i) for i in [7, 13, 15, 62, 76]]
        for selection in selections:
            row = db.get(selection)
            structure = row.toatoms()
            target_cluster_vector = self.cs.get_cluster_vector(structure)
            # Request a ground state with the same composition as the
            # reference structure
            species_count = {
                self.chemical_symbols[0][0]:
                structure.get_chemical_symbols().count(
                    self.chemical_symbols[0][0]),
                self.chemical_symbols[1][0]:
                structure.get_chemical_symbols().count(
                    self.chemical_symbols[1][0])
            }
            ground_state = self.gsf.get_ground_state(
                species_count=species_count)
            gs_cluster_vector = self.cs.get_cluster_vector(ground_state)
            mean_diff = np.mean(abs(target_cluster_vector - gs_cluster_vector))
            self.assertLess(mean_diff, 1e-8)

    def test_ground_state_cluster_vectors(self):
        """Tests get_ground_state functionality by comparing the cluster
        vectors for ground states obtained from simulated annealing."""
        self._test_ground_state_cluster_vectors_in_database(
            '../../../structure_databases/annealing_ground_states.db')

    def test_get_ground_state_fails_for_faulty_species_to_count(self):
        """Tests that get_ground_state fails if species_to_count is faulty."""
        # Check that get_ground_state fails if counts are provided for both
        # species on one of the active sublattices
        species_count = {
            self.chemical_symbols[0][0]: len(self.sl1_indices) - 1,
            self.chemical_symbols[0][1]: 1,
            self.chemical_symbols[1][1]: 1
        }
        with self.assertRaises(ValueError) as cm:
            self.gsf.get_ground_state(species_count=species_count)
        self.assertIn(
            'Provide counts for at most one of the species on each active sublattice '
            '({}), not {}!'.format(self.gsf._active_species,
                                   list(species_count.keys())),
            str(cm.exception))

        # Check that get_ground_state fails if the count exceeds the number of
        # sites on the first sublattice
        faulty_species = self.chemical_symbols[0][1]
        faulty_count = len(self.supercell)
        species_count = {
            faulty_species: faulty_count,
            self.chemical_symbols[1][1]: 1
        }
        n_active_sites = len([
            sym for sym in self.supercell.get_chemical_symbols()
            if sym == self.chemical_symbols[0][0]
        ])
        with self.assertRaises(ValueError) as cm:
            self.gsf.get_ground_state(species_count=species_count)
        self.assertIn(
            'The count for species {} ({}) must be a positive integer and cannot '
            'exceed the number of sites on the active sublattice '
            '({})'.format(faulty_species, faulty_count, n_active_sites),
            str(cm.exception))

        # Check that get_ground_state fails if the count exceeds the number of
        # sites on the second sublattice
        faulty_species = self.chemical_symbols[1][1]
        faulty_count = len(self.supercell)
        species_count = {
            faulty_species: faulty_count,
            self.chemical_symbols[0][1]: 1
        }
        n_active_sites = len([
            sym for sym in self.supercell.get_chemical_symbols()
            if sym == self.chemical_symbols[1][0]
        ])
        with self.assertRaises(ValueError) as cm:
            self.gsf.get_ground_state(species_count=species_count)
        self.assertIn(
            'The count for species {} ({}) must be a positive integer and cannot '
            'exceed the number of sites on the active sublattice '
            '({})'.format(faulty_species, faulty_count, n_active_sites),
            str(cm.exception))

    def test_get_ground_state_passes_for_partial_species_to_count(self):
        """Tests that get_ground_state passes when counts are provided for
        only one of the active sublattices, or not at all."""
        # Check that get_ground_state passes if a single count is provided
        species_count = {self.chemical_symbols[0][1]: 1}
        self.gsf.get_ground_state(species_count=species_count)

        # Check that get_ground_state passes if no counts are provided
        gs = self.gsf.get_ground_state()
        self.assertEqual(gs.get_chemical_formula(), "Au8Li8")
Example #2
0
class TestGroundStateFinderTriplets(unittest.TestCase):
    """Container for tests of the class functionality for a system with
    triplets."""
    def __init__(self, *args, **kwargs):
        super(TestGroundStateFinderTriplets, self).__init__(*args, **kwargs)
        self.chemical_symbols = ['Au', 'Pd']
        # Two cutoffs, i.e. pairs and triplets are included
        self.cutoffs = [3.0, 3.0]
        structure_prim = fcc111(self.chemical_symbols[0],
                                a=4.0,
                                size=(1, 1, 6),
                                vacuum=10,
                                periodic=True)
        structure_prim.wrap()
        self.structure_prim = structure_prim
        self.cs = ClusterSpace(self.structure_prim, self.cutoffs,
                               self.chemical_symbols)
        parameters = [0.0] * 4 + [0.1] * 6 + [-0.02] * 11
        self.ce = ClusterExpansion(self.cs, parameters)
        self.all_possible_structures = []
        self.supercell = self.structure_prim.repeat((2, 2, 1))
        # Brute-force reference set: every structure with a single Pd atom
        for i in range(len(self.supercell)):
            structure = self.supercell.copy()
            structure.symbols[i] = self.chemical_symbols[1]
            self.all_possible_structures.append(structure)

    def shortDescription(self):
        """Silences unittest from printing the docstrings in test cases."""
        return None

    def setUp(self):
        """Setup before each test."""
        self.gsf = icet.tools.ground_state_finder.GroundStateFinder(
            self.ce, self.supercell, verbose=False)

    def test_init(self):
        """Tests that initialization of the tested class works."""
        # initialize from ClusterExpansion instance
        gsf = icet.tools.ground_state_finder.GroundStateFinder(self.ce,
                                                               self.supercell,
                                                               verbose=False)
        self.assertIsInstance(gsf,
                              icet.tools.ground_state_finder.GroundStateFinder)

    def test_get_ground_state(self):
        """Tests get_ground_state functionality."""
        target_val = min([
            self.ce.predict(structure)
            for structure in self.all_possible_structures
        ])

        # Provide counts for first species
        species_count = {self.chemical_symbols[0]: len(self.supercell) - 1}
        ground_state = self.gsf.get_ground_state(species_count=species_count,
                                                 threads=1)
        predicted_species0 = self.ce.predict(ground_state)
        self.assertEqual(predicted_species0, target_val)

        # Provide counts for second species
        species_count = {self.chemical_symbols[1]: 1}
        ground_state = self.gsf.get_ground_state(species_count=species_count,
                                                 threads=1)
        predicted_species1 = self.ce.predict(ground_state)
        self.assertEqual(predicted_species0, predicted_species1)

        # Check that get_ground_state finds 50-50 mix when no counts are provided
        gsf = icet.tools.ground_state_finder.GroundStateFinder(self.ce,
                                                               self.supercell,
                                                               verbose=False)
        gs = gsf.get_ground_state(threads=1)
        self.assertEqual(gs.get_chemical_formula(), "Au12Pd12")

        # Ensure that an exception is raised when no solution is found
        # (a zero time limit leaves the solver no time to find one)
        gsf = icet.tools.ground_state_finder.GroundStateFinder(
            self.ce, self.supercell, solver_name='CBC', verbose=False)
        species_count = {self.chemical_symbols[1]: 1}
        # NOTE(review): a broad Exception is caught here; narrow it if the
        # finder raises a more specific type on failed optimization
        with self.assertRaises(Exception) as cm:
            gsf.get_ground_state(species_count=species_count,
                                 max_seconds=0.0,
                                 threads=1)
        self.assertIn('Optimization failed', str(cm.exception))
Example #3
0
class TestGroundStateFinderZeroParameter(unittest.TestCase):
    """Container for tests of the class functionality for a system with a zero parameter."""
    def __init__(self, *args, **kwargs):
        super(TestGroundStateFinderZeroParameter,
              self).__init__(*args, **kwargs)
        self.chemical_symbols = ['Ag', 'Au']
        self.cutoffs = [4.3]
        self.structure_prim = bulk(self.chemical_symbols[1], a=4.0)
        self.cs = ClusterSpace(self.structure_prim, self.cutoffs,
                               self.chemical_symbols)
        nonzero_ce = ClusterExpansion(self.cs, [0, 0, 0.1, -0.02])
        lolg = LocalOrbitListGenerator(
            self.cs.orbit_list, Structure.from_atoms(self.structure_prim),
            self.cs.fractional_position_tolerance)
        full_orbit_list = lolg.generate_full_orbit_list()
        # Transform the parameters with transform_parameters, zero the entry
        # at index 1 in that representation, and map the result back to the
        # parameter basis of the cluster space
        binary_parameters_zero = transform_parameters(self.structure_prim,
                                                      full_orbit_list,
                                                      nonzero_ce.parameters)
        binary_parameters_zero[1] = 0
        A = get_transformation_matrix(self.structure_prim, full_orbit_list)
        # Solve A x = b rather than forming inv(A) @ b explicitly; the result
        # is the same but solving is numerically more stable
        zero_parameters = np.linalg.solve(A, binary_parameters_zero)
        self.ce = ClusterExpansion(self.cs, zero_parameters)
        self.all_possible_structures = []
        self.supercell = self.structure_prim.repeat(2)
        # Brute-force reference set: every structure with a single Ag atom
        for i in range(len(self.supercell)):
            structure = self.supercell.copy()
            structure.symbols[i] = self.chemical_symbols[0]
            self.all_possible_structures.append(structure)

    def shortDescription(self):
        """Silences unittest from printing the docstrings in test cases."""
        return None

    def setUp(self):
        """Setup before each test."""
        self.gsf = icet.tools.ground_state_finder.GroundStateFinder(
            self.ce, self.supercell, verbose=False)

    def test_init(self):
        """Tests that initialization of the tested class works."""
        # initialize from ClusterExpansion instance
        gsf = icet.tools.ground_state_finder.GroundStateFinder(self.ce,
                                                               self.supercell,
                                                               verbose=False)
        self.assertIsInstance(gsf,
                              icet.tools.ground_state_finder.GroundStateFinder)

    def test_get_ground_state(self):
        """Tests get_ground_state functionality."""
        target_val = min([
            self.ce.predict(structure)
            for structure in self.all_possible_structures
        ])

        # Provide counts for first species
        species_count = {self.chemical_symbols[0]: 1}
        ground_state = self.gsf.get_ground_state(species_count=species_count,
                                                 threads=1)
        predicted_species0 = self.ce.predict(ground_state)
        self.assertEqual(predicted_species0, target_val)

        # Provide counts for second species (equivalent composition)
        species_count = {self.chemical_symbols[1]: len(self.supercell) - 1}
        ground_state = self.gsf.get_ground_state(species_count=species_count,
                                                 threads=1)
        predicted_species1 = self.ce.predict(ground_state)
        self.assertEqual(predicted_species0, predicted_species1)
Example #4
0
class TestGroundStateFinderInactiveSublatticeSameSpecies(unittest.TestCase):
    """Container for tests of the class functionality for a system with an
    inactive sublattice occupied by a species found on the active
    sublattice."""
    def __init__(self, *args, **kwargs):
        super(TestGroundStateFinderInactiveSublatticeSameSpecies,
              self).__init__(*args, **kwargs)
        # Active sublattice allows Ag/Au; inactive sublattice is pure Ag
        self.chemical_symbols = [['Ag', 'Au'], ['Ag']]
        self.cutoffs = [4.3]
        a = 4.0
        structure_prim = bulk(self.chemical_symbols[0][1], a=a)
        structure_prim.append(
            Atom(self.chemical_symbols[1][0], position=(a / 2, a / 2, a / 2)))
        self.structure_prim = structure_prim
        self.cs = ClusterSpace(self.structure_prim, self.cutoffs,
                               self.chemical_symbols)
        self.ce = ClusterExpansion(self.cs, [0, 0, 0.1, -0.02])
        self.all_possible_structures = []
        self.supercell = self.structure_prim.repeat(2)
        sublattices = self.cs.get_sublattices(self.supercell)
        self.n_active_sites = [
            len(subl.indices) for subl in sublattices.active_sublattices
        ]
        # Brute-force reference set: every structure with exactly one Ag on
        # the active sublattice. In the pristine supercell the active sites
        # are the Au sites; the inactive sites already hold Ag, so including
        # them would only add copies of the pristine (zero-Ag) supercell,
        # which does not satisfy the species counts tested below.
        active_host = self.chemical_symbols[0][1]
        for i, sym in enumerate(self.supercell.get_chemical_symbols()):
            if sym != active_host:
                continue
            structure = self.supercell.copy()
            structure.symbols[i] = self.chemical_symbols[0][0]
            self.all_possible_structures.append(structure)

    def shortDescription(self):
        """Silences unittest from printing the docstrings in test cases."""
        return None

    def setUp(self):
        """Setup before each test."""
        self.gsf = icet.tools.ground_state_finder.GroundStateFinder(
            self.ce, self.supercell, verbose=False)

    def test_init(self):
        """Tests that initialization of the tested class works."""
        # initialize from ClusterExpansion instance
        gsf = icet.tools.ground_state_finder.GroundStateFinder(self.ce,
                                                               self.supercell,
                                                               verbose=False)
        self.assertIsInstance(gsf,
                              icet.tools.ground_state_finder.GroundStateFinder)

    def test_get_ground_state(self):
        """Tests get_ground_state functionality."""
        target_val = min([
            self.ce.predict(structure)
            for structure in self.all_possible_structures
        ])

        # Provide counts for first species
        species_count = {self.chemical_symbols[0][0]: 1}
        ground_state = self.gsf.get_ground_state(species_count=species_count,
                                                 threads=1)
        predicted_species0 = self.ce.predict(ground_state)
        self.assertEqual(predicted_species0, target_val)

        # Provide counts for second species (equivalent composition on the
        # active sublattice)
        species_count = {
            self.chemical_symbols[0][1]: self.n_active_sites[0] - 1
        }
        ground_state = self.gsf.get_ground_state(species_count=species_count,
                                                 threads=1)
        predicted_species1 = self.ce.predict(ground_state)
        self.assertEqual(predicted_species0, predicted_species1)

    def test_get_ground_state_fails_for_faulty_species_to_count(self):
        """Tests that get_ground_state fails if species_to_count is faulty."""
        # Check that get_ground_state fails if counts are provided for multiple species
        species_count = {
            self.chemical_symbols[0][0]: 1,
            self.chemical_symbols[0][1]: self.n_active_sites[0] - 1
        }
        with self.assertRaises(ValueError) as cm:
            self.gsf.get_ground_state(species_count=species_count)
        self.assertIn(
            'Provide counts for at most one of the species on each active sublattice '
            '({}), not {}!'.format(self.gsf._active_species,
                                   list(species_count.keys())),
            str(cm.exception))

        # Check that get_ground_state fails if counts are provided for a
        # species not found on the active sublattice
        species_count = {'H': 1}
        with self.assertRaises(ValueError) as cm:
            self.gsf.get_ground_state(species_count=species_count)
        self.assertIn(
            'The species {} is not present on any of the active sublattices'
            ' ({})'.format(
                list(species_count.keys())[0],
                self.gsf._active_species),
            str(cm.exception))

        # Check that get_ground_state fails if the count exceeds the number of
        # sites on the active sublattice
        faulty_species = self.chemical_symbols[0][0]
        faulty_count = len(self.supercell)
        species_count = {faulty_species: faulty_count}
        n_active_sites = len([
            sym for sym in self.supercell.get_chemical_symbols()
            if sym == self.chemical_symbols[0][1]
        ])
        with self.assertRaises(ValueError) as cm:
            self.gsf.get_ground_state(species_count=species_count)
        self.assertIn(
            'The count for species {} ({}) must be a positive integer and cannot '
            'exceed the number of sites on the active sublattice '
            '({})'.format(faulty_species, faulty_count, n_active_sites),
            str(cm.exception))

    def test_create_cluster_maps(self):
        """Tests _create_cluster_maps functionality."""
        gsf = icet.tools.ground_state_finder.GroundStateFinder(self.ce,
                                                               self.supercell,
                                                               verbose=False)
        gsf._create_cluster_maps(self.structure_prim)

        # Test cluster to sites map
        target = [[0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0],
                  [0, 0], [0, 0]]
        self.assertEqual(target, gsf._cluster_to_sites_map)

        # Test cluster to orbit map
        target = [0, 1, 1, 1, 1, 1, 1, 2, 2, 2]
        self.assertEqual(target, gsf._cluster_to_orbit_map)

        # Test number of clusters per orbit map
        target = [1, 1, 6, 3]
        self.assertEqual(target, gsf._nclusters_per_orbit)
Example #5
0
class TestGroundStateFinder(unittest.TestCase):
    """Container for tests of the class functionality."""
    def __init__(self, *args, **kwargs):
        super(TestGroundStateFinder, self).__init__(*args, **kwargs)
        self.chemical_symbols = ['Ag', 'Au']
        self.cutoffs = [4.3]
        self.structure_prim = bulk(self.chemical_symbols[1], a=4.0)
        self.cs = ClusterSpace(self.structure_prim, self.cutoffs,
                               self.chemical_symbols)
        self.ce = ClusterExpansion(self.cs, [0, 0, 0.1, -0.02])
        self.all_possible_structures = []
        self.supercell = self.structure_prim.repeat(2)
        # Brute-force reference set: every structure with a single Ag atom
        for i in range(len(self.supercell)):
            structure = self.supercell.copy()
            structure.symbols[i] = self.chemical_symbols[0]
            self.all_possible_structures.append(structure)

    def shortDescription(self):
        """Silences unittest from printing the docstrings in test cases."""
        return None

    def setUp(self):
        """Setup before each test."""
        self.gsf = icet.tools.ground_state_finder.GroundStateFinder(
            self.ce, self.supercell, verbose=False)

    def test_mip_import(self):
        """Tests the Python-MIP import statement."""
        # Test that an error is raised if Python-MIP is not installed
        with self.assertRaises(ImportError) as cm:
            with mock.patch.dict(sys.modules, {'mip': None}):
                importlib.reload(icet.tools.ground_state_finder)
        self.assertIn(
            'Python-MIP (https://python-mip.readthedocs.io/en/latest/) is required in '
            'order to use the ground state finder.',
            str(cm.exception))

        # Test that an error is raised if the Python-MIP version is not sufficiently recent
        with self.assertRaises(VersionConflict) as cm:
            with mock.patch('mip.constants.VERSION', '1.6.2'):
                importlib.reload(icet.tools.ground_state_finder)
        self.assertIn(
            'Python-MIP version 1.6.3 or later is required in order to use the ground '
            'state finder.',
            str(cm.exception))

    def test_init(self):
        """Tests that initialization of the tested class works."""
        # initialize from ClusterExpansion instance
        gsf = icet.tools.ground_state_finder.GroundStateFinder(self.ce,
                                                               self.supercell,
                                                               verbose=False)
        self.assertIsInstance(gsf,
                              icet.tools.ground_state_finder.GroundStateFinder)

    def test_init_solver(self):
        """Tests that the solver can be set explicitly at initialization."""
        # initialize from ClusterExpansion instance with an explicit solver
        gsf = icet.tools.ground_state_finder.GroundStateFinder(
            self.ce, self.supercell, solver_name='CBC', verbose=False)
        self.assertEqual('CBC', gsf._model.solver_name.upper())

    def test_init_fails_for_ternary_with_one_active_sublattice(self):
        """Tests that initialization fails for a ternary system with one active
        sublattice."""
        chemical_symbols = ['Au', 'Ag', 'Pd']
        cs = ClusterSpace(self.structure_prim,
                          cutoffs=self.cutoffs,
                          chemical_symbols=chemical_symbols)
        ce = ClusterExpansion(cs, [0.0] * len(cs))
        with self.assertRaises(NotImplementedError) as cm:
            icet.tools.ground_state_finder.GroundStateFinder(ce,
                                                             self.supercell,
                                                             verbose=False)
        self.assertIn(
            'Currently, systems with more than two allowed species on any sublattice '
            'are not supported.',
            str(cm.exception))

    def test_optimization_status_property(self):
        """Tests the optimization_status property."""

        # Check that the optimization_status is None initially
        self.assertIsNone(self.gsf.optimization_status)

        # Check that the optimization_status is OPTIMAL if a ground state is found
        species_count = {self.chemical_symbols[0]: 1}
        self.gsf.get_ground_state(species_count=species_count, threads=1)
        self.assertEqual(str(self.gsf.optimization_status),
                         'OptimizationStatus.OPTIMAL')

    def test_model_property(self):
        """Tests the model property."""
        self.assertEqual(self.gsf.model.name, 'CE')

    def test_get_ground_state(self):
        """Tests get_ground_state functionality."""
        target_val = min([
            self.ce.predict(structure)
            for structure in self.all_possible_structures
        ])

        # Provide counts for first species
        species_count = {self.chemical_symbols[0]: 1}
        ground_state = self.gsf.get_ground_state(species_count=species_count,
                                                 threads=1)
        predicted_species0 = self.ce.predict(ground_state)
        self.assertEqual(predicted_species0, target_val)

        # Provide counts for second species (equivalent composition)
        species_count = {self.chemical_symbols[1]: len(self.supercell) - 1}
        ground_state = self.gsf.get_ground_state(species_count=species_count,
                                                 threads=1)
        predicted_species1 = self.ce.predict(ground_state)
        self.assertEqual(predicted_species0, predicted_species1)

        # Set the maximum run time; the result may then be suboptimal, so it
        # can only be equal to or higher than the optimum found above
        species_count = {self.chemical_symbols[0]: 1}
        ground_state = self.gsf.get_ground_state(species_count=species_count,
                                                 max_seconds=0.5,
                                                 threads=1)
        predicted_max_seconds = self.ce.predict(ground_state)
        self.assertGreaterEqual(predicted_max_seconds, predicted_species0)

    def test_get_ground_state_fails_for_faulty_species_to_count(self):
        """Tests that get_ground_state fails if species_to_count is faulty."""
        # Check that get_ground_state fails if counts are provided for multiple
        # species
        species_count = {
            self.chemical_symbols[0]: 1,
            self.chemical_symbols[1]: len(self.supercell) - 1
        }
        with self.assertRaises(ValueError) as cm:
            self.gsf.get_ground_state(species_count=species_count)
        self.assertIn(
            'Provide counts for at most one of the species on each active sublattice '
            '({}), not {}!'.format(self.gsf._active_species,
                                   list(species_count.keys())),
            str(cm.exception))

        # Check that get_ground_state fails if counts are provided for a
        # species not found on the active sublattice
        species_count = {'H': 1}
        with self.assertRaises(ValueError) as cm:
            self.gsf.get_ground_state(species_count=species_count)
        self.assertIn(
            'The species {} is not present on any of the active sublattices'
            ' ({})'.format(
                list(species_count.keys())[0],
                self.gsf._active_species),
            str(cm.exception))

        # Check that get_ground_state fails if the count exceeds the number of
        # sites on the active sublattice
        faulty_species = self.chemical_symbols[0]
        faulty_count = len(self.supercell) + 1
        species_count = {faulty_species: faulty_count}
        with self.assertRaises(ValueError) as cm:
            self.gsf.get_ground_state(species_count=species_count)
        self.assertIn(
            'The count for species {} ({}) must be a positive integer and cannot '
            'exceed the number of sites on the active sublattice '
            '({})'.format(faulty_species, faulty_count, len(self.supercell)),
            str(cm.exception))

        # Check that get_ground_state fails if the count is not a positive integer
        species = self.chemical_symbols[0]
        count = -1
        species_count = {species: count}
        with self.assertRaises(ValueError) as cm:
            self.gsf.get_ground_state(species_count=species_count)
        self.assertIn(
            'The count for species {} ({}) must be a positive integer and cannot '
            'exceed the number of sites on the active sublattice '
            '({})'.format(species, count, len(self.supercell)),
            str(cm.exception))

    def test_create_cluster_maps(self):
        """Tests _create_cluster_maps functionality."""
        gsf = icet.tools.ground_state_finder.GroundStateFinder(self.ce,
                                                               self.supercell,
                                                               verbose=False)
        gsf._create_cluster_maps(self.structure_prim)

        # Test cluster to sites map
        target = [[0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0],
                  [0, 0], [0, 0]]
        self.assertEqual(target, gsf._cluster_to_sites_map)

        # Test cluster to orbit map
        target = [0, 1, 1, 1, 1, 1, 1, 2, 2, 2]
        self.assertEqual(target, gsf._cluster_to_orbit_map)

        # Test number of clusters per orbit map
        target = [1, 1, 6, 3]
        self.assertEqual(target, gsf._nclusters_per_orbit)
class TestClusterExpansion(unittest.TestCase):
    """Container for tests of the class functionality."""
    def __init__(self, *args, **kwargs):
        """Builds the Au bulk structure and an Au/Pd cluster space once."""
        super().__init__(*args, **kwargs)
        self.structure = bulk('Au')
        self.cutoffs = [3.0, 3.0, 3.0]
        species = ['Au', 'Pd']
        self.cs = ClusterSpace(self.structure, self.cutoffs, species)

    def shortDescription(self):
        """Silences unittest from printing the docstrings in test cases."""
        return None

    def setUp(self):
        """Creates a fresh expansion with parameters 0, 1, ..., len(cs) - 1."""
        self.parameters = np.arange(len(self.cs))
        self.ce = ClusterExpansion(self.cs, self.parameters)

    def test_init(self):
        """Tests initialization and rejection of mismatched parameters."""
        self.assertIsInstance(self.ce, ClusterExpansion)

        # A parameter vector whose length differs from the cluster space must
        # be rejected. The expected count is taken from the cluster space
        # rather than hard-coded, so the test keeps working if the cutoffs
        # (and hence the number of orbits) change.
        with self.assertRaises(ValueError) as context:
            ClusterExpansion(self.cs, [0.0])
        expected_message = (
            'cluster_space ({}) and parameters (1) must have the'
            ' same length'.format(len(self.cs)))
        # assertIn reports both strings on failure, unlike assertTrue(a in b)
        self.assertIn(expected_message, str(context.exception))

    def test_predict(self):
        """Tests predict function."""
        predicted_val = self.ce.predict(self.structure)
        self.assertEqual(predicted_val, 10.0)

    def test_property_orders(self):
        """Tests orders property."""
        self.assertEqual(self.ce.orders, list(range(len(self.cutoffs) + 2)))

    def test_property_to_dataframe(self):
        """Tests to_dataframe() property."""
        df = self.ce.to_dataframe()
        self.assertIn('radius', df.columns)
        self.assertIn('order', df.columns)
        self.assertIn('eci', df.columns)
        self.assertEqual(len(df), len(self.parameters))

    def test_get__clusterspace_copy(self):
        """Tests get cluster space copy."""
        self.assertEqual(str(self.ce.get_cluster_space_copy()), str(self.cs))

    def test_property_parameters(self):
        """Tests parameters properties."""
        self.assertEqual(list(self.ce.parameters), list(self.parameters))

    def test_len(self):
        """Tests len functionality."""
        self.assertEqual(self.ce.__len__(), len(self.parameters))

    def test_read_write(self):
        """Tests that writing and reading back reproduces the expansion."""
        # The context manager closes and deletes the temporary file
        # deterministically, even if write/read fails, instead of leaving it
        # to garbage collection. Reading goes through the file name, so no
        # rewind of the handle is needed.
        with tempfile.NamedTemporaryFile() as temp_file:
            self.ce.write(temp_file.name)
            ce_read = ClusterExpansion.read(temp_file.name)

        # check cluster space
        self.assertEqual(self.cs._input_structure,
                         ce_read._cluster_space._input_structure)
        self.assertEqual(self.cs._cutoffs, ce_read._cluster_space._cutoffs)
        self.assertEqual(self.cs._input_chemical_symbols,
                         ce_read._cluster_space._input_chemical_symbols)

        # check parameters
        self.assertIsInstance(ce_read.parameters, np.ndarray)
        self.assertEqual(list(ce_read.parameters), list(self.parameters))

        # check metadata
        self.assertEqual(len(self.ce.metadata), len(ce_read.metadata))
        self.assertSequenceEqual(sorted(self.ce.metadata.keys()),
                                 sorted(ce_read.metadata.keys()))
        for key in self.ce.metadata.keys():
            self.assertEqual(self.ce.metadata[key], ce_read.metadata[key])

    def test_read_write_pruned(self):
        """Tests that a pruned expansion survives a write/read round trip."""
        self.ce.prune(indices=[2, 3])
        self.ce.prune(tol=3)
        pruned_params = self.ce.parameters
        pruned_cs_len = len(self.ce._cluster_space)

        # The context manager closes and deletes the temporary file
        # deterministically, even if write/read fails.
        with tempfile.NamedTemporaryFile() as temp_file:
            self.ce.write(temp_file.name)
            ce_read = ClusterExpansion.read(temp_file.name)
        params_read = ce_read.parameters
        cs_len_read = len(ce_read._cluster_space)

        # check cluster space and parameters
        self.assertEqual(cs_len_read, pruned_cs_len)
        self.assertEqual(list(params_read), list(pruned_params))

    def test_prune_cluster_expansion(self):
        """Tests prune() before and after zeroing most parameters."""
        n_initial = len(self.ce)
        self.ce.prune()
        n_unpruned = len(self.ce)
        self.assertEqual(n_initial, n_unpruned)

        # zero every parameter except the first three
        self.ce._parameters = np.zeros(n_unpruned)
        self.ce._parameters[:3] = [1.0, 2.0, 0.5]
        self.ce.prune()
        self.assertEqual(len(self.ce), 3)
        self.assertNotEqual(len(self.ce), n_unpruned)

    def test_prune_cluster_expansion_tol(self):
        """Tests prune(tol=...) with one parameter below the tolerance."""
        n_initial = len(self.ce)
        self.ce.prune()
        n_unpruned = len(self.ce)
        self.assertEqual(n_initial, n_unpruned)

        # keep two non-zero parameters; the second (0.01) lies below the
        # tolerance used below
        self.ce._parameters = np.zeros(n_unpruned)
        self.ce._parameters[:2] = [1.0, 0.01]
        self.ce.prune(tol=0.02)
        self.assertEqual(len(self.ce), 1)
        self.assertNotEqual(len(self.ce), n_unpruned)

    def test_prune_pairs(self):
        """Tests that pruning all order-2 terms leaves no pair rows."""
        def pair_rows(frame):
            # dataframe index labels of the rows whose order is 2
            return frame.index[frame['order'] == 2].tolist()

        self.ce.prune(indices=pair_rows(self.ce.to_dataframe()))
        self.assertEqual(pair_rows(self.ce.to_dataframe()), [])

    def test_prune_zerolet(self):
        """Tests pruning zerolet."""
        with self.assertRaises(ValueError) as context:
            self.ce.prune(indices=[0])
        self.assertTrue('zerolet may not be pruned' in str(context.exception))

    def test_plot_ecis(self):
        """Tests plot_ecis."""
        self.ce.plot_ecis()

    def test_repr(self):
        """Tests repr functionality."""
        retval = self.ce.__repr__()
        target = """
================================================ Cluster Expansion =================================================
 space group                            : Fm-3m (225)
 chemical species                       : ['Au', 'Pd'] (sublattice A)
 cutoffs                                : 3.0000 3.0000 3.0000
 total number of parameters             : 5
 number of parameters by order          : 0= 1  1= 1  2= 1  3= 1  4= 1
 fractional_position_tolerance          : 2e-06
 position_tolerance                     : 1e-05
 symprec                                : 1e-05
 total number of nonzero parameters     : 4
 number of nonzero parameters by order  : 0= 0  1= 1  2= 1  3= 1  4= 1
--------------------------------------------------------------------------------------------------------------------
index | order |  radius  | multiplicity | orbit_index | multi_component_vector | sublattices | parameter |    ECI
--------------------------------------------------------------------------------------------------------------------
   0  |   0   |   0.0000 |        1     |      -1     |           .            |      .      |         0 |         0
   1  |   1   |   0.0000 |        1     |       0     |          [0]           |      A      |         1 |         1
   2  |   2   |   1.4425 |        6     |       1     |         [0, 0]         |     A-A     |         2 |     0.333
   3  |   3   |   1.6657 |        8     |       2     |       [0, 0, 0]        |    A-A-A    |         3 |     0.375
   4  |   4   |   1.7667 |        2     |       3     |      [0, 0, 0, 0]      |   A-A-A-A   |         4 |         2
====================================================================================================================
"""  # noqa

        self.assertEqual(strip_surrounding_spaces(target),
                         strip_surrounding_spaces(retval))

    def test_get_string_representation(self):
        """Tests _get_string_representation functionality."""
        retval = self.ce._get_string_representation(print_threshold=2,
                                                    print_minimum=1)
        target = """
================================================ Cluster Expansion =================================================
 space group                            : Fm-3m (225)
 chemical species                       : ['Au', 'Pd'] (sublattice A)
 cutoffs                                : 3.0000 3.0000 3.0000
 total number of parameters             : 5
 number of parameters by order          : 0= 1  1= 1  2= 1  3= 1  4= 1
 fractional_position_tolerance          : 2e-06
 position_tolerance                     : 1e-05
 symprec                                : 1e-05
 total number of nonzero parameters     : 4
 number of nonzero parameters by order  : 0= 0  1= 1  2= 1  3= 1  4= 1
--------------------------------------------------------------------------------------------------------------------
index | order |  radius  | multiplicity | orbit_index | multi_component_vector | sublattices | parameter |    ECI
--------------------------------------------------------------------------------------------------------------------
   0  |   0   |   0.0000 |        1     |      -1     |           .            |      .      |         0 |         0
 ...
   4  |   4   |   1.7667 |        2     |       3     |      [0, 0, 0, 0]      |   A-A-A-A   |         4 |         2
====================================================================================================================
"""  # noqa
        self.assertEqual(strip_surrounding_spaces(target),
                         strip_surrounding_spaces(retval))

    def test_print_overview(self):
        """Tests print_overview functionality."""
        with StringIO() as capturedOutput:
            sys.stdout = capturedOutput  # redirect stdout
            self.ce.print_overview()
            sys.stdout = sys.__stdout__  # reset redirect
            self.assertTrue('Cluster Expansion' in capturedOutput.getvalue())