Beispiel #1
0
    def test_tolerances(self):
        """Tests varying tolerances.

        The second atom of the bcc structure is displaced by 1e-5. With tight
        tolerances (1e-9) the orbit list is expected to contain 395 orbits,
        with loose tolerances (1e-3) only 84.
        """
        structure = bulk('Al', crystalstructure='bcc', a=4, cubic=True)
        structure[1].position += 1e-5
        cutoffs = [8, 8, 8]

        # (position tolerance, expected number of orbits); symprec equals the
        # position tolerance and the fractional tolerance is ten times larger
        cases = [(1e-9, 395),  # low tol
                 (1e-3, 84)]   # high tol
        for pos_tol, expected_count in cases:
            with self.subTest(position_tolerance=pos_tol):
                orbit_list = OrbitList(
                    structure, cutoffs,
                    symprec=pos_tol, position_tolerance=pos_tol,
                    fractional_position_tolerance=pos_tol * 10)
                self.assertEqual(len(orbit_list), expected_count)
Beispiel #2
0
    def _test_equivalent_sites(self, structure):
        """
        Tests that the permutations to the representative cluster take the
        equivalent sites to the representative sites.

        For each orbit, every equivalent cluster is translated to the unit cell
        and permuted with its permutation to the representative cluster. For
        at least one of these permuted clusters, the columns returned by
        ``_get_all_columns_from_sites`` must contain the representative sites.
        """
        cutoffs = [1.6, 1.6]
        orbit_list = OrbitList(
            structure, cutoffs,
            symprec=self.symprec, position_tolerance=self.position_tolerance,
            fractional_position_tolerance=self.fractional_position_tolerance)

        for orbit in orbit_list.orbits:
            match_repr_site = False
            # Representative sites are compared as they are (not translated)
            repr_sites = orbit.sites_of_representative_cluster
            # Pair each equivalent cluster with its permutation to the
            # representative cluster
            for eq_sites, perm in zip(orbit.equivalent_clusters,
                                      orbit.permutations_to_representative):
                trans_eq_sites = orbit_list._get_sites_translated_to_unitcell(eq_sites, False)
                # Permute each translated cluster and collect the columns
                # matching the permuted sites
                for sites in trans_eq_sites:
                    perm_sites = get_permutation(sites, perm)
                    columns = orbit_list._get_all_columns_from_sites(perm_sites)
                    # Check representative sites can be found in columns
                    if repr_sites in columns:
                        match_repr_site = True
            self.assertTrue(match_repr_site)
    def __init__(self,
                 structure: Atoms,
                 cutoffs: List[float],
                 chemical_symbols: Union[List[str], List[List[str]]],
                 symprec: float = 1e-5,
                 position_tolerance: float = None) -> None:
        """
        Validates the input, determines the occupied primitive structure,
        derives the position tolerances, builds the orbit list and finally
        initializes the C++ base class.

        Parameters
        ----------
        structure
            ASE Atoms object; must be periodic in all three directions
        cutoffs
            cutoff radii passed on to the orbit list (a copy is stored)
        chemical_symbols
            chemical symbols allowed on the sites (a deep copy is stored)
        symprec
            tolerance passed as ``symprec`` to the primitive-structure search
            and to the orbit list; must be positive
        position_tolerance
            tolerance for comparing positions; defaults to ``symprec`` and must
            be positive when given

        Raises
        ------
        TypeError
            if ``structure`` is not an ASE Atoms object
        ValueError
            if ``structure`` is not periodic in all directions, or if
            ``symprec`` or ``position_tolerance`` is not positive
        """
        if not isinstance(structure, Atoms):
            raise TypeError('Input configuration must be an ASE Atoms object'
                            ', not type {}'.format(type(structure)))
        if not all(structure.pbc):
            raise ValueError('Input structure must have periodic boundary conditions')
        if symprec <= 0:
            raise ValueError('symprec must be a positive number')
        # Validate position_tolerance together with the other arguments so that
        # invalid input is rejected before the primitive-structure search runs.
        if position_tolerance is None:
            position_tolerance = symprec
        elif position_tolerance <= 0:
            raise ValueError('position_tolerance must be a positive number')

        self._config = {'symprec': symprec,
                        'position_tolerance': position_tolerance}
        self._cutoffs = cutoffs.copy()
        self._input_structure = structure.copy()
        self._input_chemical_symbols = copy.deepcopy(chemical_symbols)
        chemical_symbols = self._get_chemical_symbols()

        self._pruning_history = []

        # set up primitive
        occupied_primitive, primitive_chemical_symbols = get_occupied_primitive_structure(
            self._input_structure, chemical_symbols, symprec=self.symprec)
        self._primitive_chemical_symbols = primitive_chemical_symbols
        assert len(occupied_primitive) == len(primitive_chemical_symbols)

        # derived tolerance: the position tolerance relative to the cube root of
        # the primitive cell volume, capped at position_tolerance / 5 and
        # rounded to one significant digit
        effective_box_size = abs(np.linalg.det(occupied_primitive.cell)) ** (1 / 3)
        tol = self.position_tolerance / effective_box_size
        tol = min(tol, self.position_tolerance / 5)
        self._config['fractional_position_tolerance'] = round(tol, -int(floor(log10(abs(tol)))))

        # set up orbit list
        self._orbit_list = OrbitList(
            structure=occupied_primitive,
            cutoffs=self._cutoffs,
            symprec=self.symprec,
            position_tolerance=self.position_tolerance,
            fractional_position_tolerance=self.fractional_position_tolerance)
        self._orbit_list.remove_inactive_orbits(primitive_chemical_symbols)

        # call (base) C++ constructor
        _ClusterSpace.__init__(self,
                               chemical_symbols=primitive_chemical_symbols,
                               orbit_list=self._orbit_list,
                               position_tolerance=self.position_tolerance,
                               fractional_position_tolerance=self.fractional_position_tolerance)
Beispiel #4
0
    def _test_allowed_permutations(self, structure):
        """Tests allowed permutations of orbits in orbit list.

        This test works in the following fashion.
        For each orbit in orbit_list:
        1. Translate the representative sites to the unit cell.
        2. Permute the translated sites.
        3. Get the sites from all columns of the permutation matrix
           that map simultaneously to the permuted sites.
        4. If the permutation is not allowed, check that at least one of the
           translated sites cannot be found in the columns obtained in the
           previous step.
        5. If at least one of the translated sites is found in the columns,
           append the respective permutation to the allowed_perm list.
        6. Check that the allowed_perm list is equal to
           orbit.allowed_permutations.
        """
        cutoffs = [1.6, 1.6]
        orbit_list = OrbitList(
            structure, cutoffs,
            symprec=self.symprec, position_tolerance=self.position_tolerance,
            fractional_position_tolerance=self.fractional_position_tolerance)

        for orbit in orbit_list.orbits:
            # Set up all possible permutations
            allowed_perm = []
            all_perm = \
                [list(perm) for perm in permutations(range(orbit.order))]
            # Get representative sites of orbit and translate them to the unit cell
            repr_sites = orbit.sites_of_representative_cluster
            translated_sites = \
                orbit_list._get_sites_translated_to_unitcell(repr_sites, False)
            for sites in translated_sites:
                for perm in all_perm:
                    # Permute translated sites
                    perm_sites = get_permutation(sites, perm)
                    # Get from all columns those sites at the rows
                    # where the permuted sites are found in column1.
                    columns = \
                        orbit_list._get_all_columns_from_sites(perm_sites)
                    # A disallowed permutation must leave at least one of the
                    # translated clusters outside the collected columns
                    if perm not in orbit.allowed_permutations:
                        self.assertTrue(
                            any(s not in columns for s in translated_sites))
                    # If any translated cluster is found, record the permutation
                    if perm not in allowed_perm and \
                            any(s in columns for s in translated_sites):
                        allowed_perm.append(perm)
            # Check all collected permutations match allowed_permutations
            self.assertEqual(sorted(allowed_perm),
                             sorted(orbit.allowed_permutations))
Beispiel #5
0
 def __init__(self, *args, **kwargs):
     """Builds the test structures and the orbit list shared by all tests."""
     super().__init__(*args, **kwargs)
     # four-atom hcp supercell (Ni, Fe, Ni, Ni) and its primitive cell
     self.structure = bulk('Ni', 'hcp', a=2.0).repeat([2, 1, 1])
     self.structure_prim = bulk('Ni', 'hcp', a=2.0)
     self.structure.set_chemical_symbols('NiFeNi2')
     self.icet_structure = Structure.from_atoms(self.structure)
     self.cutoffs = [2.2]
     self.symprec = 1e-5
     self.position_tolerance = 1e-5
     self.fractional_position_tolerance = 1e-6
     self.orbit_list = OrbitList(
         self.structure_prim, self.cutoffs,
         symprec=self.symprec, position_tolerance=self.position_tolerance,
         fractional_position_tolerance=self.fractional_position_tolerance)
     self.orbit_list.sort(self.position_tolerance)
 def __init__(self, *args, **kwargs):
     """Builds the hcp orbit list and the supercell shared by all tests."""
     super().__init__(*args, **kwargs)
     prim_structure = bulk('Ni', 'hcp', a=4.0)
     cutoffs = [4.2, 4.2]
     self.symprec = 1e-5
     self.position_tolerance = 1e-5
     self.fractional_position_tolerance = 1e-6
     self.orbit_list = OrbitList(
         prim_structure, cutoffs,
         symprec=self.symprec, position_tolerance=self.position_tolerance,
         fractional_position_tolerance=self.fractional_position_tolerance)
     self.primitive = self.orbit_list.get_primitive_structure()
     # skewed supercell containing eight primitive cells (determinant 8)
     super_structure = make_supercell(prim_structure, [[2, 0, 1000],
                                                       [0, 2, 0],
                                                       [0, 0, 2]])
     self.supercell = Structure.from_atoms(super_structure)
Beispiel #7
0
 def test_init(self):
     """Constructing an OrbitList yields an OrbitList instance."""
     tolerances = dict(symprec=self.symprec,
                       position_tolerance=self.position_tolerance,
                       fractional_position_tolerance=self.fractional_position_tolerance)
     constructed = OrbitList(self.structure, self.cutoffs, **tolerances)
     self.assertIsInstance(constructed, OrbitList)
    def setUp(self):
        """Create the Al orbit list, a skewed supercell and its generator."""
        primitive_atoms = bulk('Al')
        self.orbit_list = OrbitList(
            primitive_atoms, [4.2, 4.2],
            symprec=self.symprec,
            position_tolerance=self.position_tolerance,
            fractional_position_tolerance=self.fractional_position_tolerance)
        self.primitive = self.orbit_list.get_primitive_structure()
        transformation = [[2, 0, 1000],
                          [0, 2, 0],
                          [0, 0, 2]]
        self.supercell = Structure.from_atoms(
            make_supercell(primitive_atoms, transformation))

        self.lolg = LocalOrbitListGenerator(
            self.orbit_list, self.supercell,
            fractional_position_tolerance=self.fractional_position_tolerance)
Beispiel #9
0
 def test_orbit_list_hcp(self):
     """
     Tests orbit list has the right number of singlets and pairs for
     an hcp structure.
     """
     structure = bulk('Ni', 'hcp', a=3.0)
     cutoffs = [3.1]
     orbit_list = OrbitList(
         structure, cutoffs,
         symprec=self.symprec, position_tolerance=self.position_tolerance,
         fractional_position_tolerance=self.fractional_position_tolerance)
     # one singlet and two pairs expected
     self.assertEqual(len(orbit_list), 3)
     # singlet has multiplicity equal to 2 (two atoms in the primitive cell)
     singlet = orbit_list.get_orbit(0)
     self.assertEqual(len(singlet), 2)
     # first pair has multiplicity equal to 6
     pairs = orbit_list.get_orbit(1)
     self.assertEqual(len(pairs), 6)
     # second pair has multiplicity equal to 6
     pairs = orbit_list.get_orbit(2)
     self.assertEqual(len(pairs), 6)
     # not more orbits listed
     with self.assertRaises(IndexError):
         orbit_list.get_orbit(3)
Beispiel #10
0
 def test_orbit_list_bcc(self):
     """
     Tests orbit list has the right number of singlets and pairs for
     a bcc structure.
     """
     structure = bulk('Al', 'bcc', a=3.0)
     cutoffs = [3.0]
     orbit_list = OrbitList(
         structure, cutoffs,
         symprec=self.symprec, position_tolerance=self.position_tolerance,
         fractional_position_tolerance=self.fractional_position_tolerance)
     # one singlet and two pairs expected
     self.assertEqual(len(orbit_list), 3)
     # singlet
     singlet = orbit_list.get_orbit(0)
     self.assertEqual(len(singlet), 1)
     # first pair has multiplicity equal to 4
     pairs = orbit_list.get_orbit(1)
     self.assertEqual(len(pairs), 4)
     # second pair has multiplicity equal to 3
     pairs = orbit_list.get_orbit(2)
     self.assertEqual(len(pairs), 3)
     # not more orbits listed
     with self.assertRaises(IndexError):
         orbit_list.get_orbit(3)
class TestLocalOrbitListGeneratorHCP(unittest.TestCase):
    """
    Container for test of class functionality for hcp structure,
    which contains two atoms per unitcell.
    """

    def __init__(self, *args, **kwargs):
        super(TestLocalOrbitListGeneratorHCP, self).__init__(*args, **kwargs)
        prim_structure = bulk('Ni', 'hcp', a=4.0)
        cutoffs = [4.2, 4.2]
        self.symprec = 1e-5
        self.position_tolerance = 1e-5
        self.fractional_position_tolerance = 1e-6
        self.orbit_list = OrbitList(
            prim_structure, cutoffs,
            symprec=self.symprec, position_tolerance=self.position_tolerance,
            fractional_position_tolerance=self.fractional_position_tolerance)
        self.primitive = self.orbit_list.get_primitive_structure()
        super_structure = make_supercell(prim_structure, [[2, 0, 1000],
                                                          [0, 2, 0],
                                                          [0, 0, 2]])
        self.supercell = Structure.from_atoms(super_structure)

    def shortDescription(self):
        """Silences unittest from printing the docstrings in test cases."""
        return None

    def setUp(self):
        """Instantiate class for each test case."""
        self.lolg = LocalOrbitListGenerator(
            self.orbit_list, self.supercell, self.fractional_position_tolerance)

    def test_generate_local_orbit_list(self):
        """
        Tests that function generates an orbit list for the given
        offset of the primitive structure.
        """
        unique_offsets = self.lolg._get_unique_primcell_offsets()
        for offset in unique_offsets:
            local_orbit_list = self.lolg.generate_local_orbit_list(offset)
            for orbit_prim, orbit_super in zip(self.orbit_list.orbits,
                                               local_orbit_list.orbits):
                for site_p, site_s in zip(orbit_prim.sites_of_representative_cluster,
                                          orbit_super.sites_of_representative_cluster):
                    site_p.unitcell_offset += offset
                    pos_super = self.supercell.get_position(site_s)
                    pos_prim = self.primitive.get_position(site_p)
                    self.assertTrue(np.all(np.isclose(pos_super, pos_prim)))

    def test_unique_offset_count(self):
        """
        Tests number of unique offsets corresponds to half of the total number
        of atoms in the supercell given that there is two atoms per unitcell.
        """
        self.assertEqual(self.lolg.get_number_of_unique_offsets(),
                         len(self.supercell) / 2)

    def test_get_primitive_to_supercell_map(self):
        """Tests primitive to supercell mapping."""
        unique_offsets = self.lolg._get_unique_primcell_offsets()

        for offset in unique_offsets:
            self.lolg.generate_local_orbit_list(offset)
            mapping = self.lolg._get_primitive_to_supercell_map()
            for sites_prim, sites_super in mapping.items():
                pos_super = self.supercell.get_position(sites_super)
                pos_prim = self.primitive.get_position(sites_prim)
                self.assertTrue(np.all(np.isclose(pos_super, pos_prim)))

    def test_unique_primcell_offsets(self):
        """
        Tests primitive offsets are unique and take to positions that
        match atoms positions in the supercell.
        """
        unique_offsets = self.lolg._get_unique_primcell_offsets()
        super_pos = self.supercell.positions

        for k, offset in enumerate(unique_offsets):
            pos_prim = self.primitive.get_position(LatticeSite(0, offset))
            self.assertTrue(
                np.any(np.isclose(pos_prim, pos) for pos in super_pos))
            for i in range(k + 1, len(unique_offsets)):
                self.assertFalse(np.all(np.isclose(offset, unique_offsets[i])))
Beispiel #12
0
 def setUp(self):
     """Create a fresh orbit list before every test."""
     tolerance_kwargs = dict(symprec=self.symprec,
                             position_tolerance=self.position_tolerance,
                             fractional_position_tolerance=self.fractional_position_tolerance)
     self.orbit_list = OrbitList(self.structure, self.cutoffs, **tolerance_kwargs)
Beispiel #13
0
class TestClusterCounts(unittest.TestCase):
    """ Container for test of the module functionality. """

    def __init__(self, *args, **kwargs):
        super(TestClusterCounts, self).__init__(*args, **kwargs)
        self.structure = bulk('Ni', 'hcp', a=2.0).repeat([2, 1, 1])
        self.structure_prim = bulk('Ni', 'hcp', a=2.0)
        self.structure.set_chemical_symbols('NiFeNi2')
        self.icet_structure = Structure.from_atoms(self.structure)
        self.cutoffs = [2.2]
        self.symprec = 1e-5
        self.position_tolerance = 1e-5
        self.fractional_position_tolerance = 1e-6
        self.orbit_list = OrbitList(self.structure_prim, self.cutoffs,
                                    self.symprec, self.position_tolerance,
                                    self.fractional_position_tolerance)
        self.orbit_list.sort(self.position_tolerance)

    def shortDescription(self):
        """Silences unittest from printing the docstrings in test cases."""
        return None

    def setUp(self):
        """ Sets up an empty cluster counts object. """
        self.cluster_counts = ClusterCounts(self.orbit_list,
                                            self.structure,
                                            self.fractional_position_tolerance)

    def test_count_list_lattice_sites(self):
        """
        Tests whether cluster_counts returns the correct number of pairs
        given a list of lattice neighbors.
        """
        lattice_sites = []
        lattice_sites.append(LatticeSite(0, [0., 0., 0.]))
        lattice_sites.append(LatticeSite(1, [0., 0., 0.]))

        lattice_sites2 = []
        lattice_sites2.append(LatticeSite(0, [0., 0., 0.]))
        lattice_sites2.append(LatticeSite(2, [0., 0., 0.]))

        # The tag is set to -1 in order not to collide with an existing
        # cluster-tag as the latter is used to generate a hash (and defaults to
        # 0, which is certainly already taken).
        cluster = Cluster(self.icet_structure, lattice_sites, tag=-1)

        lattice_neighbors = [lattice_sites, lattice_sites2]

        self.cluster_counts.count(self.icet_structure, lattice_neighbors, cluster, True)
        cluster_map = self.cluster_counts.get_cluster_counts()

        count = cluster_map[cluster]
        self.assertEqual(count, {('Ni', 'Fe'): 1, ('Ni', 'Ni'): 1})

    def test_count_orbit_list(self):
        """Tests cluster_counts given orbits in an orbit list."""
        cluster_singlet = Cluster(self.icet_structure, [], 0)
        cluster_pair = Cluster(self.icet_structure, [], 1)
        clusters = [cluster_singlet, cluster_pair]

        expected_counts = [{('Fe',): 1, ('Ni',): 3},
                           {('Fe', 'Fe'): 1, ('Fe', 'Ni'): 4,
                            ('Ni', 'Ni'): 7}]

        for k, cluster in enumerate(clusters):
            count = self.cluster_counts[cluster]
            self.assertEqual(count, expected_counts[k])

    def test_len(self):
        """Tests total size of counts."""
        self.assertEqual(len(self.cluster_counts),
                         len(self.orbit_list))

    def test_reset(self):
        """Tests reset functionality."""
        self.cluster_counts.reset()
        self.assertEqual(len(self.cluster_counts), 0)

    def test_getitem(self):
        """Tests __getitem__ functionality."""
        # Test with integer as key
        self.assertEqual(self.cluster_counts[2],
                         {('Fe', 'Ni'): 6, ('Ni', 'Ni'): 6})
        # Test with icet Cluster as key
        cluster = list(self.cluster_counts.cluster_counts.keys())[0]
        self.assertEqual(self.cluster_counts[cluster], {
                         ('Fe',): 1, ('Ni',): 3})

    def test_str(self):
        """Tests representation of cluster_counts."""
        retval = self.cluster_counts.__str__()
        target = """
====================== Cluster Counts ======================
Singlet: [0] [] 0.0000
Fe   1
Ni   3

Pair: [0, 0] [2.0] 1.0000
Fe  Fe   1
Fe  Ni   4
Ni  Ni   7

Pair: [0, 0] [2.0] 1.0000
Fe  Ni   6
Ni  Ni   6
============================================================
"""
        self.assertEqual(strip_surrounding_spaces(target),
                         strip_surrounding_spaces(retval))

    def test_get_cluster_counts(self):
        """Tests get_cluster_counts functionality."""
        counts = self.cluster_counts.get_cluster_counts()
        for cluster, cluster_info in counts.items():
            # check size of cluster match the number of elements in counts
            for elements in cluster_info:
                self.assertEqual(len(cluster), len(elements))
Beispiel #14
0
class TestOrbitList(unittest.TestCase):
    """Container for test of the module functionality."""

    def __init__(self, *args, **kwargs):
        """Sets up the structure, tolerances and reference clusters for all tests."""
        super().__init__(*args, **kwargs)
        self.cutoffs = [4.2]
        self.symprec = 1e-5
        self.position_tolerance = 1e-5
        self.fractional_position_tolerance = 1e-6
        self.structure = bulk('Ag', 'sc', a=4.09)

        # representative clusters for testing
        # for singlet
        self.cluster_singlet = Cluster(
            Structure.from_atoms(self.structure),
            [LatticeSite(0, [0, 0, 0])])
        # for pair: site 0 and its image one cell along the first lattice vector
        lattice_sites = [LatticeSite(0, [i, 0, 0]) for i in range(3)]
        self.cluster_pair = Cluster(Structure.from_atoms(self.structure),
                                    [lattice_sites[0], lattice_sites[1]],
                                    True)

    def shortDescription(self):
        """Suppress unittest's use of test docstrings as descriptions."""
        description = None
        return description

    def setUp(self):
        """Create a fresh orbit list before every test."""
        tolerance_kwargs = dict(symprec=self.symprec,
                                position_tolerance=self.position_tolerance,
                                fractional_position_tolerance=self.fractional_position_tolerance)
        self.orbit_list = OrbitList(self.structure, self.cutoffs, **tolerance_kwargs)

    def test_init(self):
        """Constructing an OrbitList yields an OrbitList instance."""
        tolerances = dict(symprec=self.symprec,
                          position_tolerance=self.position_tolerance,
                          fractional_position_tolerance=self.fractional_position_tolerance)
        constructed = OrbitList(self.structure, self.cutoffs, **tolerances)
        self.assertIsInstance(constructed, OrbitList)

    def test_tolerances(self):
        """Tests varying tolerances.

        The second atom of the bcc structure is displaced by 1e-5. With tight
        tolerances (1e-9) the orbit list is expected to contain 395 orbits,
        with loose tolerances (1e-3) only 84.
        """
        structure = bulk('Al', crystalstructure='bcc', a=4, cubic=True)
        structure[1].position += 1e-5
        cutoffs = [8, 8, 8]

        # (position tolerance, expected number of orbits); symprec equals the
        # position tolerance and the fractional tolerance is ten times larger
        cases = [(1e-9, 395),  # low tol
                 (1e-3, 84)]   # high tol
        for pos_tol, expected_count in cases:
            with self.subTest(position_tolerance=pos_tol):
                orbit_list = OrbitList(
                    structure, cutoffs,
                    symprec=pos_tol, position_tolerance=pos_tol,
                    fractional_position_tolerance=pos_tol * 10)
                self.assertEqual(len(orbit_list), expected_count)

    def test_property_matrix_of_equivalent_positions(self):
        """Compare the matrix property with one built directly from the structure."""
        raw_matrix, primitive, _ = matrix_of_equivalent_positions_from_structure(
            self.structure, self.cutoffs[0], self.position_tolerance, self.symprec)
        expected = _get_lattice_site_matrix_of_equivalent_positions(
            primitive, raw_matrix,
            fractional_position_tolerance=self.fractional_position_tolerance,
            prune=True)
        actual = self.orbit_list.matrix_of_equivalent_positions
        self.assertEqual(actual, expected)

    def test_add_orbit(self):
        """Tests add_orbit functionality: adding an orbit grows the list by one."""
        len_before = len(self.orbit_list)
        orbit = Orbit(self.cluster_pair)
        self.orbit_list.add_orbit(orbit)
        self.assertEqual(len(self.orbit_list), len_before + 1)

    def test_get_number_of_nbody_clusters(self):
        """Exactly one two-body (pair) orbit is expected in the orbit list."""
        number_of_pairs = self.orbit_list.get_number_of_nbody_clusters(2)
        self.assertEqual(number_of_pairs, 1)

    def test_get_orbit(self):
        """Tests that get_orbit returns the orbit at the given index."""
        # get singlet
        orbit = self.orbit_list.get_orbit(0)
        self.assertEqual(orbit.order, 1)
        # get pair
        orbit = self.orbit_list.get_orbit(1)
        self.assertEqual(orbit.order, 2)
        # check that the first index past the end raises an error
        with self.assertRaises(IndexError):
            self.orbit_list.get_orbit(len(self.orbit_list))

    def test_clear(self):
        """After clear() not even the first orbit can be retrieved."""
        orbit_list = self.orbit_list
        orbit_list.clear()
        with self.assertRaises(IndexError):
            orbit_list.get_orbit(0)

    def test_sort(self):
        """Tests that sorting the orbit list only reorders its orbits.

        NOTE: the resulting order itself is not checked here.
        """
        orbits_before = sorted(str(orbit) for orbit in self.orbit_list.orbits)
        self.orbit_list.sort(self.position_tolerance)
        orbits_after = sorted(str(orbit) for orbit in self.orbit_list.orbits)
        self.assertEqual(orbits_after, orbits_before)

#    def test_find_orbit_index(self):
#        """
#        Tests that orbit index returned for the given representative cluster.
#        """
#        # TODO: test that a non-representative cluster returns -1
#        self.assertEqual(
#            self.orbit_list._find_orbit_index(self.cluster_singlet), 0)
#        self.assertEqual(
#            self.orbit_list._find_orbit_index(self.cluster_pair), 1)

    def test_is_row_taken(self):
        """Check the private _is_row_taken helper for a free and a taken row."""
        row_indices = [0, 1, 2]
        self.assertFalse(self.orbit_list._is_row_taken(set(), row_indices))
        self.assertTrue(
            self.orbit_list._is_row_taken({tuple(row_indices)}, row_indices))

    def test_get_orbit_list(self):
        """Tests that the orbits have the expected representative clusters."""
        # clusters for testing
        repr_clusters = [self.cluster_singlet, self.cluster_pair]

        orbits = self.orbit_list.orbits
        # a length mismatch should fail the test rather than raise IndexError
        self.assertEqual(len(orbits), len(repr_clusters))
        for orbit, repr_cluster in zip(orbits, repr_clusters):
            with self.subTest(orbit=orbit):
                self.assertEqual(str(orbit.representative_cluster),
                                 str(repr_cluster))

    def test_remove_all_orbits(self):
        """An all-Al occupation leaves no active orbits behind."""
        primitive_size = len(self.orbit_list.get_primitive_structure())
        chemical_symbols = [['Al'] * primitive_size]
        self.assertNotEqual(len(self.orbit_list), 0)
        self.orbit_list.remove_inactive_orbits(chemical_symbols)
        self.assertEqual(len(self.orbit_list), 0)

    def test_get_primitive_structure(self):
        """The primitive structure is returned as an icet Structure."""
        primitive = self.orbit_list.get_primitive_structure()
        self.assertIsInstance(primitive, Structure)

    def test_len(self):
        """The fixture orbit list holds exactly two orbits."""
        expected_length = 2
        self.assertEqual(len(self.orbit_list), expected_length)

    def test_get_supercell_orbit_list(self):
        """Tests orbit list is returned for the given supercell."""
        # TODO : Tests fails for an actual supercell of the testing structure
        structure_supercell = self.structure.copy()
        orbit_list_super = \
            self.orbit_list.get_supercell_orbit_list(
                structure_supercell, self.position_tolerance)
        orbit_list_super.sort(self.position_tolerance)
        self.orbit_list.sort(self.position_tolerance)
        # guard against a vacuous pass when the supercell list is shorter
        self.assertEqual(len(orbit_list_super), len(self.orbit_list))
        for k in range(len(orbit_list_super)):
            orbit_super = orbit_list_super.get_orbit(k)
            orbit = self.orbit_list.get_orbit(k)
            self.assertEqual(str(orbit), str(orbit_super))

    def test_translate_sites_to_unitcell(self):
        """Tests the get all translated sites functionality."""
        translate = self.orbit_list._get_sites_translated_to_unitcell

        # a site without offset should only be translated to itself
        sites = [LatticeSite(0, [0, 0, 0])]
        target = [[LatticeSite(0, [0, 0, 0])]]
        self.assertListEqual(translate(sites, False), target)

        # test a singlet site with offset
        sites = [LatticeSite(3, [0, 0, -1])]
        target = [[LatticeSite(3, [0, 0, -1])],
                  [LatticeSite(3, [0, 0, 0])]]
        self.assertListEqual(translate(sites, False), target)

        # sort output
        self.assertListEqual(translate(sites, True), sorted(target))

        # does it still work when the offset consists of floats?
        sites = [LatticeSite(0, [0.0, 0.0, 0.0])]
        target = [[LatticeSite(0, [0.0, 0.0, 0.0])]]
        self.assertListEqual(translate(sites, False), target)

        # Test two sites with floats
        sites = [LatticeSite(0, [1.0, 0.0, 0.0]),
                 LatticeSite(0, [0.0, 0.0, 0.0])]
        target = [[LatticeSite(0, [0.0, 0.0, 0.0]),
                   LatticeSite(0, [-1., 0.0, 0.0])],
                  sites]
        self.assertListEqual(translate(sites, False), target)

        # Test sites where none is inside unit cell
        sites = [LatticeSite(0, [1.0, 2.0, -1.0]),
                 LatticeSite(2, [2.0, 0.0, 0.0])]

        target = [[LatticeSite(0, [-1.0, 2.0, -1.0]),
                   LatticeSite(2, [0.0, 0.0, 0.0])],
                  [LatticeSite(0, [0.0, 0.0, 0.0]),
                   LatticeSite(2, [1.0, -2.0, 1.0])],
                  sites]
        self.assertListEqual(translate(sites, False), target)

    def test_get_all_columns_from_sites(self):
        """Tests get_all_columns_from_sites functionality."""
        # These sites are first and last elements in column1
        sites = [LatticeSite(0, [0., 0., 0.]),
                 LatticeSite(0, [1., 0., 0.])]

        pm = self.orbit_list.matrix_of_equivalent_positions
        columns = self.orbit_list._get_all_columns_from_sites(sites)
        for i in range(len(pm[0])):
            perm_sites = [pm[0][i], pm[-1][i]]
            translated_sites = \
                self.orbit_list._get_sites_translated_to_unitcell(perm_sites,
                                                                  False)
            # the stride of 2 assumes each column yields exactly two
            # translated clusters (one per site of the pair)
            for k, translated in enumerate(translated_sites):
                self.assertEqual(columns[k + 2 * i], translated)

    def _test_allowed_permutations(self, structure):
        """Tests allowed permutations of orbits in orbit list.

        This test works in the following fashion.
        For each orbit in orbit_list:
        1. Translate the representative sites to the unit cell.
        2. Permute the translated sites.
        3. Get the sites from all columns of the permutation matrix
           that map simultaneously to the permuted sites.
        4. If the permutation is not allowed, check that at least one of the
           translated sites cannot be found in the columns obtained in the
           previous step.
        5. If at least one of the translated sites is found in the columns,
           append the respective permutation to the allowed_perm list.
        6. Check that the allowed_perm list is equal to
           orbit.allowed_permutations.
        """
        cutoffs = [1.6, 1.6]
        orbit_list = OrbitList(
            structure, cutoffs,
            symprec=self.symprec, position_tolerance=self.position_tolerance,
            fractional_position_tolerance=self.fractional_position_tolerance)

        for orbit in orbit_list.orbits:
            # Set up all possible permutations
            allowed_perm = []
            all_perm = \
                [list(perm) for perm in permutations(range(orbit.order))]
            # Get representative sites of orbit and translate them to the unit cell
            repr_sites = orbit.sites_of_representative_cluster
            translated_sites = \
                orbit_list._get_sites_translated_to_unitcell(repr_sites, False)
            for sites in translated_sites:
                for perm in all_perm:
                    # Permute translated sites
                    perm_sites = get_permutation(sites, perm)
                    # Get from all columns those sites at the rows
                    # where the permuted sites are found in column1.
                    columns = \
                        orbit_list._get_all_columns_from_sites(perm_sites)
                    # A disallowed permutation must leave at least one of the
                    # translated clusters outside the collected columns
                    if perm not in orbit.allowed_permutations:
                        self.assertTrue(
                            any(s not in columns for s in translated_sites))
                    # If any translated cluster is found, record the permutation
                    if perm not in allowed_perm and \
                            any(s in columns for s in translated_sites):
                        allowed_perm.append(perm)
            # Check all collected permutations match allowed_permutations
            self.assertEqual(sorted(allowed_perm),
                             sorted(orbit.allowed_permutations))

    def _test_equivalent_sites(self, structure):
        """
        Checks that applying permutations_to_representative to the equivalent
        clusters of each orbit leads back to the representative cluster.
        """
        cutoffs = [1.6, 1.6]
        orbit_list = OrbitList(
            structure, cutoffs,
            symprec=self.symprec, position_tolerance=self.position_tolerance,
            fractional_position_tolerance=self.fractional_position_tolerance)

        for orbit in orbit_list.orbits:
            found = False
            representative = orbit.sites_of_representative_cluster
            clusters_and_perms = zip(orbit.equivalent_clusters,
                                     orbit.permutations_to_representative)
            for cluster, perm_to_repr in clusters_and_perms:
                # translate the equivalent cluster into the unit cell, permute
                # each translation and look for the representative cluster
                translations = orbit_list._get_sites_translated_to_unitcell(cluster, False)
                for translated in translations:
                    permuted = get_permutation(translated, perm_to_repr)
                    candidates = orbit_list._get_all_columns_from_sites(permuted)
                    if representative in candidates:
                        found = True
            # NOTE(review): the flag is shared by all equivalent clusters of the
            # orbit, so a single match is enough to pass; confirm whether every
            # equivalent cluster was meant to be checked individually.
            self.assertTrue(found)

    def test_orbit_permutations_for_structure_in_database(self):
        """
        Runs the allowed-permutation and equivalent-site checks on every
        periodic (pbc=TTT) structure stored in the test database.
        """
        database = ase_connect('structures_for_testing.db')
        for entry in database.select('pbc=TTT'):
            atoms = entry.toatoms()
            with self.subTest(structure_tag=entry.tag):
                self._test_allowed_permutations(atoms)
                self._test_equivalent_sites(atoms)

    def test_orbit_list_fcc(self):
        """
        Tests orbit list has the right number of singlets and pairs for
        an fcc structure.
        """
        structure = bulk('Al', 'fcc', a=3.0)
        cutoffs = [2.5]
        orbit_list = OrbitList(
            structure, cutoffs,
            symprec=self.symprec, position_tolerance=self.position_tolerance,
            fractional_position_tolerance=self.fractional_position_tolerance)
        # only a singlet and a pair are expected
        self.assertEqual(len(orbit_list), 2)
        # singlet
        singlet = orbit_list.get_orbit(0)
        self.assertEqual(len(singlet), 1)
        # pair has multiplicity equal to 6
        pairs = orbit_list.get_orbit(1)
        self.assertEqual(len(pairs), 6)
        # not more orbits listed
        with self.assertRaises(IndexError):
            orbit_list.get_orbit(2)

    def test_orbit_list_bcc(self):
        """
        Tests orbit list has the right number of singlets and pairs for
        a bcc structure.
        """
        structure = bulk('Al', 'bcc', a=3.0)
        cutoffs = [3.0]
        orbit_list = OrbitList(
            structure, cutoffs,
            symprec=self.symprec, position_tolerance=self.position_tolerance,
            fractional_position_tolerance=self.fractional_position_tolerance)
        # one singlet and two pairs expected
        self.assertEqual(len(orbit_list), 3)
        # singlet
        singlet = orbit_list.get_orbit(0)
        self.assertEqual(len(singlet), 1)
        # first pair has multiplicity equal to 4
        pairs = orbit_list.get_orbit(1)
        self.assertEqual(len(pairs), 4)
        # second pair has multiplicity equal to 3
        pairs = orbit_list.get_orbit(2)
        self.assertEqual(len(pairs), 3)
        # not more orbits listed
        with self.assertRaises(IndexError):
            orbit_list.get_orbit(3)

    def test_orbit_list_hcp(self):
        """
        Tests orbit list has the right number of singlets and pairs for
        an hcp structure.
        """
        structure = bulk('Ni', 'hcp', a=3.0)
        cutoffs = [3.1]
        orbit_list = OrbitList(
            structure, cutoffs,
            symprec=self.symprec, position_tolerance=self.position_tolerance,
            fractional_position_tolerance=self.fractional_position_tolerance)
        # one singlet and two pairs expected
        self.assertEqual(len(orbit_list), 3)
        # singlet orbit contains two equivalent sites
        singlet = orbit_list.get_orbit(0)
        self.assertEqual(len(singlet), 2)
        # first pair has multiplicity equal to 6
        pairs = orbit_list.get_orbit(1)
        self.assertEqual(len(pairs), 6)
        # second pair has multiplicity equal to 6
        pairs = orbit_list.get_orbit(2)
        self.assertEqual(len(pairs), 6)
        # not more orbits listed
        with self.assertRaises(IndexError):
            orbit_list.get_orbit(3)

    def test_remove_orbit(self):
        """Tests removing orbits by index."""
        current_size = len(self.orbit_list)

        # iterate from the highest index down so that each index is still in
        # range after the previous removal
        for index in reversed(range(current_size)):
            self.orbit_list.remove_orbit(index)
            current_size -= 1
            self.assertEqual(len(self.orbit_list), current_size)
class ClusterSpace(_ClusterSpace):
    """This class provides functionality for generating and maintaining
    cluster spaces.

    **Note:** In icet all :class:`ase.Atoms` objects must have
    periodic boundary conditions. When carrying out cluster expansions
    for surfaces and nanoparticles it is therefore recommended to
    surround the structure with vacuum and use periodic boundary
    conditions. This can be done using e.g., :func:`ase.Atoms.center`.

    Parameters
    ----------
    structure : ase.Atoms
        atomic structure
    cutoffs : list(float)
        cutoff radii per order that define the cluster space

        Cutoffs are specified in units of Angstrom and refer to the
        longest distance between two atoms in the cluster. The first
        element refers to pairs, the second to triplets, the third
        to quadruplets, and so on. ``cutoffs=[7.0, 4.5]`` thus implies
        that all pairs distanced 7 A or less will be included,
        as well as all triplets among which the longest distance is no
        longer than 4.5 A.
    chemical_symbols : list(str) or list(list(str))
        list of chemical symbols, each of which must map to an element
        of the periodic table

        If a list of chemical symbols is provided, all sites on the
        lattice will have the same allowed occupations as the input
        list.

        If a list of list of chemical symbols is provided then the
        outer list must be the same length as the `structure` object and
        ``chemical_symbols[i]`` will correspond to the allowed species
        on lattice site ``i``.
    symprec : float
        tolerance imposed when analyzing the symmetry using spglib;
        must be positive
    position_tolerance : float
        tolerance applied when comparing positions in Cartesian coordinates;
        must be positive; if ``None`` (default) the value of ``symprec``
        is used

    Raises
    ------
    TypeError
        if ``structure`` is not an :class:`ase.Atoms` object or
        ``chemical_symbols`` has an unsupported type
    ValueError
        if ``structure`` lacks periodic boundary conditions, if ``symprec``
        or ``position_tolerance`` is not positive, or if
        ``chemical_symbols`` is inconsistent (wrong length, duplicate
        symbols on a site, no active sites)

    Examples
    --------
    The following snippets illustrate several common situations::

        >>> from ase.build import bulk
        >>> from ase.io import read
        >>> from icet import ClusterSpace

        >>> # AgPd alloy with pairs up to 7.0 A and triplets up to 4.5 A
        >>> prim = bulk('Ag')
        >>> cs = ClusterSpace(structure=prim, cutoffs=[7.0, 4.5],
        ...                   chemical_symbols=[['Ag', 'Pd']])
        >>> print(cs)

        >>> # (Mg,Zn)O alloy on rocksalt lattice with pairs up to 8.0 A
        >>> prim = bulk('MgO', crystalstructure='rocksalt', a=6.0)
        >>> cs = ClusterSpace(structure=prim, cutoffs=[8.0],
        ...                   chemical_symbols=[['Mg', 'Zn'], ['O']])
        >>> print(cs)

        >>> # (Ga,Al)(As,Sb) alloy with pairs, triplets, and quadruplets
        >>> prim = bulk('GaAs', crystalstructure='zincblende', a=6.5)
        >>> cs = ClusterSpace(structure=prim, cutoffs=[7.0, 6.0, 5.0],
        ...                   chemical_symbols=[['Ga', 'Al'], ['As', 'Sb']])
        >>> print(cs)

        >>> # PdCuAu alloy with pairs and triplets
        >>> prim = bulk('Pd')
        >>> cs = ClusterSpace(structure=prim, cutoffs=[7.0, 5.0],
        ...                   chemical_symbols=[['Au', 'Cu', 'Pd']])
        >>> print(cs)

    """

    def __init__(self,
                 structure: Atoms,
                 cutoffs: List[float],
                 chemical_symbols: Union[List[str], List[List[str]]],
                 symprec: float = 1e-5,
                 position_tolerance: float = None) -> None:

        if not isinstance(structure, Atoms):
            raise TypeError('Input configuration must be an ASE Atoms object'
                            ', not type {}'.format(type(structure)))
        if not all(structure.pbc):
            raise ValueError('Input structure must have periodic boundary conditions')
        if symprec <= 0:
            raise ValueError('symprec must be a positive number')

        self._config = {'symprec': symprec}
        self._cutoffs = cutoffs.copy()
        self._input_structure = structure.copy()
        self._input_chemical_symbols = copy.deepcopy(chemical_symbols)
        chemical_symbols = self._get_chemical_symbols()

        # index lists passed to _prune_orbit_list; replayed by read() and copy()
        self._pruning_history = []

        # set up primitive
        occupied_primitive, primitive_chemical_symbols = get_occupied_primitive_structure(
            self._input_structure, chemical_symbols, symprec=self.symprec)
        self._primitive_chemical_symbols = primitive_chemical_symbols
        assert len(occupied_primitive) == len(primitive_chemical_symbols)

        # derived tolerances
        if position_tolerance is None:
            self._config['position_tolerance'] = symprec
        else:
            if position_tolerance <= 0:
                raise ValueError('position_tolerance must be a positive number')
            self._config['position_tolerance'] = position_tolerance
        # fractional tolerance: Cartesian tolerance divided by the cube root of the
        # primitive cell volume, capped at position_tolerance / 5 and rounded to
        # one significant figure
        effective_box_size = abs(np.linalg.det(occupied_primitive.cell)) ** (1 / 3)
        tol = self.position_tolerance / effective_box_size
        tol = min(tol, self._config['position_tolerance'] / 5)
        self._config['fractional_position_tolerance'] = round(tol, -int(floor(log10(abs(tol)))))

        # set up orbit list
        self._orbit_list = OrbitList(
            structure=occupied_primitive,
            cutoffs=self._cutoffs,
            symprec=self.symprec,
            position_tolerance=self.position_tolerance,
            fractional_position_tolerance=self.fractional_position_tolerance)
        self._orbit_list.remove_inactive_orbits(primitive_chemical_symbols)

        # call (base) C++ constructor
        _ClusterSpace.__init__(self,
                               chemical_symbols=primitive_chemical_symbols,
                               orbit_list=self._orbit_list,
                               position_tolerance=self.position_tolerance,
                               fractional_position_tolerance=self.fractional_position_tolerance)

    def _get_chemical_symbols(self):
        """ Returns chemical symbols using input structure and
        chemical symbols. Carries out multiple sanity checks.

        Returns
        -------
        list(list(str))
            allowed chemical symbols for each site of the input structure

        Raises
        ------
        TypeError
            if the input chemical symbols are neither a list of str nor a
            list of lists
        ValueError
            if a list of lists does not match the length of the input
            structure, if a site lists the same symbol more than once, or
            if no site allows more than one species
        """

        # setup chemical symbols as List[List[str]]
        if all(isinstance(i, str) for i in self._input_chemical_symbols):
            chemical_symbols = [
                self._input_chemical_symbols] * len(self._input_structure)
        elif not all(isinstance(i, list) for i in self._input_chemical_symbols):
            raise TypeError("chemical_symbols must be List[str] or List[List[str]], not {}".format(
                type(self._input_chemical_symbols)))
        elif len(self._input_chemical_symbols) != len(self._input_structure):
            msg = 'chemical_symbols must have same length as structure. '
            msg += 'len(chemical_symbols) = {}, len(structure)= {}'.format(
                len(self._input_chemical_symbols), len(self._input_structure))
            raise ValueError(msg)
        else:
            chemical_symbols = copy.deepcopy(self._input_chemical_symbols)

        for i, symbols in enumerate(chemical_symbols):
            if len(symbols) != len(set(symbols)):
                raise ValueError(
                    'Found duplicates of allowed chemical symbols on site {}.'
                    ' allowed species on  site {}= {}'.format(i, i, symbols))

        # a site is active if it allows more than one species
        if not any(len(symbols) > 1 for symbols in chemical_symbols):
            raise ValueError('No active sites found')

        return chemical_symbols

    def _get_chemical_symbol_representation(self):
        """Returns a str version of the chemical symbols that is
        easier on the eyes (one entry per active sublattice).
        """
        sublattices = self.get_sublattices(self.primitive_structure)
        nice_str = []
        for sublattice in sublattices.active_sublattices:
            sublattice_symbol = sublattice.symbol

            nice_str.append('{} (sublattice {})'.format(
                list(sublattice.chemical_symbols), sublattice_symbol))
        return ', '.join(nice_str)

    def _get_string_representation(self,
                                   print_threshold: int = None,
                                   print_minimum: int = 10) -> str:
        """
        String representation of the cluster space that provides an overview of
        the orbits (order, radius, multiplicity etc) that constitute the space.

        Parameters
        ----------
        print_threshold
            if the number of orbits exceeds this number print dots
        print_minimum
            number of lines printed from the top and the bottom of the orbit
            list if `print_threshold` is exceeded

        Returns
        -------
        multi-line string
            string representation of the cluster space.
        """

        def repr_orbit(orbit, header=False):
            formats = {'order': '{:2}',
                       'radius': '{:8.4f}',
                       'multiplicity': '{:4}',
                       'index': '{:4}',
                       'orbit_index': '{:4}',
                       'multi_component_vector': '{:}',
                       'sublattices': '{:}'}
            s = []
            for name, value in orbit.items():
                str_repr = formats[name].format(value)
                n = max(len(name), len(str_repr))
                if header:
                    s += ['{s:^{n}}'.format(s=name, n=n)]
                else:
                    s += ['{s:^{n}}'.format(s=str_repr, n=n)]
            return ' | '.join(s)

        # orbit_data is recomputed on every access, so fetch it only once
        orbit_list_info = self.orbit_data

        # basic information
        # (use largest orbit to obtain maximum line length)
        prototype_orbit = orbit_list_info[-1]
        width = len(repr_orbit(prototype_orbit))
        s = []  # type: List
        s += ['{s:=^{n}}'.format(s=' Cluster Space ', n=width)]
        s += [' {:38} : {}'.format('space group', self.space_group)]
        s += [' {:38} : {}'
              .format('chemical species', self._get_chemical_symbol_representation())]
        s += [' {:38} : {}'.format('cutoffs', ' '
                                   .join(['{:.4f}'.format(co) for co in self._cutoffs]))]
        s += [' {:38} : {}'.format('total number of parameters', len(self))]
        t = ['{}= {}'.format(k, c)
             for k, c in self.get_number_of_orbits_by_order().items()]
        s += [' {:38} : {}'.format('number of parameters by order', '  '.join(t))]
        for key, value in sorted(self._config.items()):
            s += [' {:38} : {}'.format(key, value)]

        # table header
        s += [''.center(width, '-')]
        s += [repr_orbit(prototype_orbit, header=True)]
        s += [''.center(width, '-')]

        # table body
        index = 0
        while index < len(orbit_list_info):
            # after the first print_minimum rows jump to the last print_minimum
            # rows; the strict '<' ensures dots appear only if a row is skipped
            if (print_threshold is not None and
                    len(self) > print_threshold and
                    index >= print_minimum and
                    index < len(self) - print_minimum):
                index = len(self) - print_minimum
                s += [' ...']
            s += [repr_orbit(orbit_list_info[index])]
            index += 1
        s += [''.center(width, '=')]

        return '\n'.join(s)

    def __repr__(self) -> str:
        """ String representation. """
        return self._get_string_representation(print_threshold=50)

    def print_overview(self,
                       print_threshold: int = None,
                       print_minimum: int = 10) -> None:
        """
        Print an overview of the cluster space in terms of the orbits (order,
        radius, multiplicity etc).

        Parameters
        ----------
        print_threshold
            if the number of orbits exceeds this number print dots
        print_minimum
            number of lines printed from the top and the bottom of the orbit
            list if `print_threshold` is exceeded
        """
        print(self._get_string_representation(print_threshold=print_threshold,
                                              print_minimum=print_minimum))

    @property
    def symprec(self) -> float:
        """ tolerance imposed when analyzing the symmetry using spglib """
        return self._config['symprec']

    @property
    def position_tolerance(self) -> float:
        """ tolerance applied when comparing positions in Cartesian coordinates """
        return self._config['position_tolerance']

    @property
    def fractional_position_tolerance(self) -> float:
        """ tolerance applied when comparing positions in fractional coordinates """
        return self._config['fractional_position_tolerance']

    @property
    def space_group(self) -> str:
        """ space group of the primitive structure in international notation (via spglib) """
        structure_as_tuple = ase_atoms_to_spglib_cell(self.primitive_structure)
        return spglib.get_spacegroup(structure_as_tuple, symprec=self.symprec)

    @property
    def orbit_data(self) -> List[OrderedDict]:
        """
        list of orbits with information regarding
        order, radius, multiplicity etc

        The first entry describes the zerolet; every further entry
        corresponds to one index ``1 <= index < len(self)``.
        """
        data = []
        zerolet = OrderedDict([('index', 0),
                               ('order', 0),
                               ('radius', 0),
                               ('multiplicity', 1),
                               ('orbit_index', -1),
                               ('multi_component_vector', '.'),
                               ('sublattices', '.')])
        sublattices = self.get_sublattices(self.primitive_structure)
        data.append(zerolet)
        # loop-invariant; fetched once rather than on every iteration
        primitive = self._get_primitive_structure()
        for index in range(1, len(self)):
            multi_component_vectors_by_orbit = self.get_multi_component_vectors_by_orbit(index)
            orbit_index = multi_component_vectors_by_orbit[0]
            mc_vector = multi_component_vectors_by_orbit[1]
            # fetch the orbit once and reuse it for all quantities below
            orbit = self.get_orbit(orbit_index)
            repr_sites = orbit.sites_of_representative_cluster
            orbit_sublattices = '-'.join(
                [sublattices[sublattices.get_sublattice_index(ls.index)].symbol
                 for ls in repr_sites])
            local_Mi = self.get_number_of_allowed_species_by_site(primitive, repr_sites)
            mc_vectors = orbit.get_mc_vectors(local_Mi)
            mc_permutations = self.get_multi_component_vector_permutations(
                mc_vectors, orbit_index)
            mc_index = mc_vectors.index(mc_vector)
            mc_permutations_multiplicity = len(mc_permutations[mc_index])
            cluster = orbit.representative_cluster

            multiplicity = len(orbit.equivalent_clusters)
            record = OrderedDict([('index', index),
                                  ('order', cluster.order),
                                  ('radius', cluster.radius),
                                  ('multiplicity', multiplicity *
                                   mc_permutations_multiplicity),
                                  ('orbit_index', orbit_index)])
            record['multi_component_vector'] = mc_vector
            record['sublattices'] = orbit_sublattices
            data.append(record)
        return data

    def get_number_of_orbits_by_order(self) -> OrderedDict:
        """
        Returns the number of orbits by order.

        Returns
        -------
        an ordered dictionary where keys and values represent order and number
        of orbits, respectively
        """
        count_orbits = {}  # type: dict[int, int]
        for orbit in self.orbit_data:
            k = orbit['order']
            count_orbits[k] = count_orbits.get(k, 0) + 1
        return OrderedDict(sorted(count_orbits.items()))

    def get_cluster_vector(self, structure: Atoms) -> np.ndarray:
        """
        Returns the cluster vector for a structure.

        Parameters
        ----------
        structure
            atomic configuration

        Returns
        -------
        the cluster vector

        Raises
        ------
        TypeError
            if ``structure`` is not an :class:`ase.Atoms` object
        """
        if not isinstance(structure, Atoms):
            raise TypeError('Input structure must be an ASE Atoms object')

        try:
            cv = _ClusterSpace.get_cluster_vector(
                self,
                structure=Structure.from_atoms(structure),
                fractional_position_tolerance=self.fractional_position_tolerance)
        except Exception:
            # prefer a descriptive compatibility error if one applies;
            # otherwise re-raise the original exception unchanged
            self.assert_structure_compatibility(structure)
            raise
        return cv

    def get_coordinates_of_representative_cluster(self, orbit_index: int) -> List[Tuple[float]]:
        """
        Returns the positions of atoms in the selected orbit

        Parameters
        ----------
        orbit_index
            index of the orbit from which to calculate the positions of the atoms

        Returns
        -------
        list of positions of atoms in the selected orbit

        Raises
        ------
        ValueError
            if ``orbit_index`` is not a valid index into the orbit list
        """
        # Raise exception if chosen orbit index not in current list of orbit indices
        if orbit_index not in range(len(self._orbit_list)):
            raise ValueError('The input orbit index is not in the list of possible values.')

        lattice_sites = self._orbit_list.get_orbit(orbit_index).sites_of_representative_cluster
        positions = []

        for site in lattice_sites:
            pos = get_position_from_lattice_site(structure=self.primitive_structure,
                                                 lattice_site=site)
            positions.append(pos)

        return positions

    def _prune_orbit_list(self, indices: List[int]) -> None:
        """
        Prunes the internal orbit list

        Parameters
        ----------
        indices
            indices to all orbits to be removed
        """
        size_before = len(self._orbit_list)

        self._prune_orbit_list_cpp(indices)
        # remove from the highest index down so lower indices remain valid
        for index in sorted(indices, reverse=True):
            self._orbit_list.remove_orbit(index)
        self._compute_multi_component_vectors()

        size_after = len(self._orbit_list)
        assert size_before - len(indices) == size_after
        # store a copy so later mutation of the caller's list cannot
        # alter the history replayed by read() and copy()
        self._pruning_history.append(list(indices))

    @property
    def primitive_structure(self) -> Atoms:
        """ Primitive structure on which cluster space is based """
        structure = self._get_primitive_structure().to_atoms()
        # Decorate with the "real" symbols (instead of H, He, Li etc)
        for atom, symbols in zip(structure, self._primitive_chemical_symbols):
            atom.symbol = min(symbols)
        return structure

    @property
    def chemical_symbols(self) -> List[List[str]]:
        """ Species identified by their chemical symbols (a copy) """
        # deep copy so callers cannot mutate the per-site inner lists
        return copy.deepcopy(self._primitive_chemical_symbols)

    @property
    def cutoffs(self) -> List[float]:
        """
        Cutoffs for different n-body clusters. The cutoff radius (in
        Angstroms) defines the largest interatomic distance in a
        cluster. A copy is returned so the cluster space cannot be
        modified through it.
        """
        return self._cutoffs.copy()

    @property
    def orbit_list(self):
        """Orbit list that defines the cluster in the cluster space"""
        return self._orbit_list

    def get_possible_orbit_occupations(self, orbit_index: int) \
            -> List[Tuple[str, ...]]:
        """Returns possible occupation of the orbit.

        Parameters
        ----------
        orbit_index
            index of the orbit in the orbit list

        Returns
        -------
        every combination of allowed chemical symbols on the sites of the
        representative cluster of the orbit, one tuple per combination in
        site order
        """
        orbit = self.orbit_list.orbits[orbit_index]

        indices = [
            lattice_site.index for lattice_site in orbit.sites_of_representative_cluster]

        allowed_species = [self.chemical_symbols[index] for index in indices]

        return list(itertools.product(*allowed_species))

    def get_sublattices(self, structure: Atoms) -> Sublattices:
        """
        Returns the sublattices of the input structure.

        Parameters
        ----------
        structure
            structure the sublattices are based on
        """
        sl = Sublattices(self.chemical_symbols,
                         self.primitive_structure,
                         structure,
                         fractional_position_tolerance=self.fractional_position_tolerance)
        return sl

    def assert_structure_compatibility(self, structure: Atoms, vol_tol: float = 1e-5) -> None:
        """ Raises error if structure is not compatible with ClusterSpace.

        Todo
        ----
        Add check for if structure is relaxed.

        Parameters
        ----------
        structure
            structure to check if compatible with ClusterSpace
        vol_tol
            maximum allowed absolute difference in volume per atom between
            ``structure`` and the primitive structure

        Raises
        ------
        ValueError
            if the volume per atom does not match or if ``structure`` lacks
            periodic boundary conditions; the occupation check is delegated
            to :meth:`Sublattices.assert_occupation_is_allowed`
        """
        # check volume
        prim = self.primitive_structure
        vol1 = prim.get_volume() / len(prim)
        vol2 = structure.get_volume() / len(structure)
        if abs(vol1 - vol2) > vol_tol:
            raise ValueError('Volume per atom of structure does not match the volume of '
                             'ClusterSpace.primitive_structure')

        # check occupations
        sublattices = self.get_sublattices(structure)
        sublattices.assert_occupation_is_allowed(structure.get_chemical_symbols())

        # check pbc
        if not all(structure.pbc):
            raise ValueError('Input structure must have periodic boundary conditions')

    def is_supercell_self_interacting(self, structure: Atoms) -> bool:
        """
        Checks whether a structure has self-interactions via periodic
        boundary conditions.

        Parameters
        ----------
        structure
            structure to be tested

        Returns
        -------
        bool
            If True, the structure contains self-interactions via periodic
            boundary conditions, otherwise False.
        """
        ol = self.orbit_list.get_supercell_orbit_list(
            structure=structure,
            fractional_position_tolerance=self.fractional_position_tolerance)
        orbit_indices = set()
        for orbit in ol.orbits:
            for sites in orbit.equivalent_clusters:
                # two clusters over the same set of site indices => self-interaction
                indices = tuple(sorted([site.index for site in sites]))
                if indices in orbit_indices:
                    return True
                else:
                    orbit_indices.add(indices)
        return False

    def write(self, filename: str) -> None:
        """
        Saves cluster space to a file.

        The file is a tar archive with an ``items`` member (pickled
        settings and pruning history) and an ``atoms`` member (the input
        structure in ASE JSON format).

        Parameters
        ----------
        filename
            name of file to which to write
        """

        with tarfile.open(name=filename, mode='w') as tar_file:

            # write items
            items = dict(cutoffs=self._cutoffs,
                         chemical_symbols=self._input_chemical_symbols,
                         pruning_history=self._pruning_history,
                         symprec=self.symprec,
                         position_tolerance=self.position_tolerance)
            with tempfile.TemporaryFile() as temp_file:
                pickle.dump(items, temp_file)
                temp_file.seek(0)
                tar_info = tar_file.gettarinfo(arcname='items', fileobj=temp_file)
                tar_file.addfile(tar_info, temp_file)

            # write structure
            with tempfile.NamedTemporaryFile() as temp_file:
                ase_write(temp_file.name, self._input_structure, format='json')
                temp_file.seek(0)
                tar_info = tar_file.gettarinfo(arcname='atoms', fileobj=temp_file)
                tar_file.addfile(tar_info, temp_file)

    @staticmethod
    def read(filename: str):
        """
        Reads cluster space from filename.

        **Warning:** the settings are stored with :mod:`pickle`, which can
        execute arbitrary code; only read files from trusted sources.

        Parameters
        ----------
        filename
            name of file (or file object) from which to read cluster space
        """
        if isinstance(filename, str):
            tar_file = tarfile.open(mode='r', name=filename)
        else:
            tar_file = tarfile.open(mode='r', fileobj=filename)

        with tar_file:
            # read items (pickle: trusted input only, see docstring)
            items = pickle.load(tar_file.extractfile('items'))

            # read structure; ase reads the temporary file by name, so the
            # written bytes must be flushed first
            with tempfile.NamedTemporaryFile() as temp_file:
                temp_file.write(tar_file.extractfile('atoms').read())
                temp_file.flush()
                structure = ase_read(temp_file.name, format='json')

        # ensure backward compatibility
        if 'symprec' not in items:  # pragma: no cover
            items['symprec'] = 1e-5
        if 'position_tolerance' not in items:  # pragma: no cover
            items['position_tolerance'] = items['symprec']

        cs = ClusterSpace(structure=structure,
                          cutoffs=items['cutoffs'],
                          chemical_symbols=items['chemical_symbols'],
                          symprec=items['symprec'],
                          position_tolerance=items['position_tolerance'])
        for indices in items['pruning_history']:
            cs._prune_orbit_list(indices)
        return cs

    def copy(self):
        """ Returns copy of ClusterSpace instance. """
        cs_copy = ClusterSpace(structure=self._input_structure,
                               cutoffs=self.cutoffs,
                               chemical_symbols=self._input_chemical_symbols,
                               symprec=self.symprec,
                               position_tolerance=self.position_tolerance)
        for indices in self._pruning_history:
            cs_copy._prune_orbit_list(indices)
        return cs_copy
# Start import
from ase.build import bulk
from icet.core.cluster_counts import ClusterCounts
from icet.core.orbit_list import OrbitList
# End import

# Create a simple-cubic titanium structure, repeat it twice along the first
# cell vector, and occupy the second of the two resulting sites with W.
# Start setup
prim_structure = bulk('Ti', 'sc', a=3.0)
structure = prim_structure.repeat([2, 1, 1])
structure.set_chemical_symbols(['Ti', 'W'])
cutoffs = [5.0]
# End setup

# Determine the orbit list for the corresponding primitive structure for all
# pair clusters within the cutoff distance
symprec = 1e-5  # tolerance used by spglib
position_tolerance = 1e-5  # tolerance used when comparing positions
fractional_position_tolerance = position_tolerance / 3  # ... in fractional coordinates
prim_orbitlist = OrbitList(prim_structure, cutoffs, symprec,
                           position_tolerance, fractional_position_tolerance)
# Use the primitive orbit list to count the clusters present in the
# (non-primitive) structure created above.
cluster_counts = ClusterCounts(prim_orbitlist, structure,
                               fractional_position_tolerance)
# Print the number of atoms, the number of orbits, and the cluster counts.
print('Number of atoms: {0}'.format(len(structure)))
print('Found {} orbits'.format(len(cluster_counts)))
print(cluster_counts)
class TestLocalOrbitListGenerator(unittest.TestCase):
    """Container for test of class functionality."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # tolerances shared by every OrbitList / LocalOrbitListGenerator
        # constructed in this test case
        self.symprec = 1e-5
        self.position_tolerance = 1e-5
        self.fractional_position_tolerance = 1e-6

    def shortDescription(self):
        """Silences unittest from printing the docstrings in test cases."""
        return None

    def setUp(self):
        """Instantiate class for each test case."""
        prim_structure = bulk('Al')
        cutoffs = [4.2, 4.2]
        self.orbit_list = OrbitList(
            prim_structure, cutoffs,
            symprec=self.symprec, position_tolerance=self.position_tolerance,
            fractional_position_tolerance=self.fractional_position_tolerance)
        self.primitive = self.orbit_list.get_primitive_structure()
        # deliberately skewed supercell (large off-diagonal element)
        super_structure = make_supercell(prim_structure, [[2, 0, 1000],
                                                          [0, 2, 0],
                                                          [0, 0, 2]])
        self.supercell = Structure.from_atoms(super_structure)

        self.lolg = LocalOrbitListGenerator(
            self.orbit_list, self.supercell,
            fractional_position_tolerance=self.fractional_position_tolerance)

    def test_generate_local_orbit_list_from_index(self):
        """
        Tests that function generates an orbit list from
        an index of a specific offset of the primitive structure.
        """
        unique_offsets = self.lolg._get_unique_primcell_offsets()

        for index, offset in enumerate(unique_offsets):
            local_orbit_list = self.lolg.generate_local_orbit_list(index)
            for orbit_prim, orbit_super in zip(self.orbit_list.orbits,
                                               local_orbit_list.orbits):
                for site_p, site_s in zip(orbit_prim.sites_of_representative_cluster,
                                          orbit_super.sites_of_representative_cluster):
                    site_p.unitcell_offset += offset
                    pos_super = self.supercell.get_position(site_s)
                    pos_prim = self.primitive.get_position(site_p)
                    self.assertTrue(np.all(np.isclose(pos_super, pos_prim)))

    def test_generate_local_orbit_list_from_offset(self):
        """
        Tests that function generates an orbit list for the given
        offset of the primitive structure.
        """
        unique_offsets = self.lolg._get_unique_primcell_offsets()
        for offset in unique_offsets:
            local_orbit_list = self.lolg.generate_local_orbit_list(offset)
            for orbit_prim, orbit_super in zip(self.orbit_list.orbits,
                                               local_orbit_list.orbits):
                for site_p, site_s in zip(orbit_prim.sites_of_representative_cluster,
                                          orbit_super.sites_of_representative_cluster):
                    site_p.unitcell_offset += offset
                    pos_super = self.supercell.get_position(site_s)
                    pos_prim = self.primitive.get_position(site_p)
                    self.assertTrue(np.all(np.isclose(pos_super, pos_prim)))

    def test_generating_full_orbit_list_with_primitive(self):
        """
        Tests creating a full orbit list using the primitive as the supercell.
        """
        prim_structure = bulk('Al')
        prim = Structure.from_atoms(prim_structure)
        lolg = LocalOrbitListGenerator(
            self.orbit_list, prim,
            fractional_position_tolerance=self.fractional_position_tolerance)
        lolg.generate_full_orbit_list()

    def test_generate_full_orbit_list(self):
        """
        Tests that equivalent sites of all local orbit lists are listed
        as equivalent sites in the full orbit list.
        """
        fol = self.lolg.generate_full_orbit_list()
        for offset in self.lolg._get_unique_primcell_offsets():
            lol = self.lolg.generate_local_orbit_list(offset)
            for orbit, orbit_ in zip(lol.orbits, fol.orbits):
                for sites in orbit.equivalent_clusters:
                    self.assertIn(sites, orbit_.equivalent_clusters)

    def test_clear(self):
        """
        Tests vector of unique offsets and primitive to supercell mapping
        are cleared.
        """

        self.lolg.generate_local_orbit_list(2)
        self.lolg.clear()
        offsets_count = self.lolg.get_number_of_unique_offsets()
        self.assertEqual(offsets_count, 0)
        mapping = self.lolg._get_primitive_to_supercell_map()
        self.assertEqual(len(mapping), 0)

    def test_unique_offset_count(self):
        """
        Tests number of unique offsets corresponds to the number of atoms
        in the supercell given that there is one atom in the primitive cell.
        """
        self.assertEqual(self.lolg.get_number_of_unique_offsets(),
                         len(self.supercell))

    def test_get_primitive_to_supercell_map(self):
        """Tests primitive to supercell mapping."""
        unique_offsets = self.lolg._get_unique_primcell_offsets()

        for offset in unique_offsets:
            self.lolg.generate_local_orbit_list(offset)
            mapping = self.lolg._get_primitive_to_supercell_map()
            for sites_prim, sites_super in mapping.items():
                pos_super = self.supercell.get_position(sites_super)
                pos_prim = self.primitive.get_position(sites_prim)
                self.assertTrue(np.all(np.isclose(pos_super, pos_prim)))

    def test_unique_primcell_offsets(self):
        """
        Tests primitive offsets are unique and take to positions that
        match atoms positions in the supercell.
        """
        unique_offsets = self.lolg._get_unique_primcell_offsets()
        super_pos = self.supercell.positions

        for k, offset in enumerate(unique_offsets):
            pos_prim = self.primitive.get_position(LatticeSite(0, offset))
            # Use the builtin any() over full-vector comparisons: passing a
            # generator to np.any() wraps the generator object itself in a
            # 0-d object array, which is always truthy, so the check could
            # never fail.
            self.assertTrue(
                any(np.allclose(pos_prim, pos) for pos in super_pos))
            for i in range(k + 1, len(unique_offsets)):
                self.assertFalse(np.all(np.isclose(offset, unique_offsets[i])))
from icet.core.orbit_list import OrbitList
from ase.build import bulk
import time

if __name__ == '__main__':

    # Timing script: measures the process CPU time spent constructing an
    # OrbitList for bulk Al with the cutoffs below.
    structure = bulk('Al')
    cutoffs = [10, 7, 6]
    tolerances = dict(symprec=1e-5,
                      position_tolerance=1e-5,
                      fractional_position_tolerance=2e-6)

    start = time.process_time()
    orbit_list = OrbitList(structure=structure, cutoffs=cutoffs, **tolerances)
    elapsed_time = time.process_time() - start

    print('Time to initialize OrbitList with cutoffs: {}, {:.6} sec'.format(
        cutoffs, elapsed_time))