class TestClusterExpansion(unittest.TestCase):
    """Container for tests of the class functionality."""
    def __init__(self, *args, **kwargs):
        super(TestClusterExpansion, self).__init__(*args, **kwargs)
        self.structure = bulk('Au')
        self.cutoffs = [3.0] * 3
        chemical_symbols = ['Au', 'Pd']
        self.cs = ClusterSpace(self.structure, self.cutoffs, chemical_symbols)

    def shortDescription(self):
        """Silences unittest from printing the docstrings in test cases."""
        return None

    def setUp(self):
        """Setup before each test."""
        params_len = len(self.cs)
        self.parameters = np.arange(params_len)
        self.ce = ClusterExpansion(self.cs, self.parameters)

    def test_init(self):
        """Tests that initialization works."""
        self.assertIsInstance(self.ce, ClusterExpansion)

        # test whether method raises Exception
        with self.assertRaises(ValueError) as context:
            ClusterExpansion(self.cs, [0.0])
        self.assertTrue('cluster_space (5) and parameters (1) must have the'
                        ' same length' in str(context.exception))

    def test_predict(self):
        """Tests predict function."""
        predicted_val = self.ce.predict(self.structure)
        self.assertEqual(predicted_val, 10.0)

    def test_property_orders(self):
        """Tests orders property."""
        self.assertEqual(self.ce.orders, list(range(len(self.cutoffs) + 2)))

    def test_property_to_dataframe(self):
        """Tests to_dataframe() property."""
        df = self.ce.to_dataframe()
        self.assertIn('radius', df.columns)
        self.assertIn('order', df.columns)
        self.assertIn('eci', df.columns)
        self.assertEqual(len(df), len(self.parameters))

    def test_get__clusterspace_copy(self):
        """Tests get cluster space copy."""
        self.assertEqual(str(self.ce.get_cluster_space_copy()), str(self.cs))

    def test_property_parameters(self):
        """Tests parameters properties."""
        self.assertEqual(list(self.ce.parameters), list(self.parameters))

    def test_len(self):
        """Tests len functionality."""
        self.assertEqual(self.ce.__len__(), len(self.parameters))

    def test_read_write(self):
        """Tests read and write functionalities."""
        # save to file
        temp_file = tempfile.NamedTemporaryFile()
        self.ce.write(temp_file.name)

        # read from file
        temp_file.seek(0)
        ce_read = ClusterExpansion.read(temp_file.name)

        # check cluster space
        self.assertEqual(self.cs._input_structure,
                         ce_read._cluster_space._input_structure)
        self.assertEqual(self.cs._cutoffs, ce_read._cluster_space._cutoffs)
        self.assertEqual(self.cs._input_chemical_symbols,
                         ce_read._cluster_space._input_chemical_symbols)

        # check parameters
        self.assertIsInstance(ce_read.parameters, np.ndarray)
        self.assertEqual(list(ce_read.parameters), list(self.parameters))

        # check metadata
        self.assertEqual(len(self.ce.metadata), len(ce_read.metadata))
        self.assertSequenceEqual(sorted(self.ce.metadata.keys()),
                                 sorted(ce_read.metadata.keys()))
        for key in self.ce.metadata.keys():
            self.assertEqual(self.ce.metadata[key], ce_read.metadata[key])

    def test_read_write_pruned(self):
        """Tests read and write functionalities."""
        # save to file
        temp_file = tempfile.NamedTemporaryFile()
        self.ce.prune(indices=[2, 3])
        self.ce.prune(tol=3)
        pruned_params = self.ce.parameters
        pruned_cs_len = len(self.ce._cluster_space)
        self.ce.write(temp_file.name)

        # read from file
        temp_file.seek(0)
        ce_read = ClusterExpansion.read(temp_file.name)
        params_read = ce_read.parameters
        cs_len_read = len(ce_read._cluster_space)

        # check cluster space
        self.assertEqual(cs_len_read, pruned_cs_len)
        self.assertEqual(list(params_read), list(pruned_params))

    def test_prune_cluster_expansion(self):
        """Tests pruning cluster expansion."""
        len_before = len(self.ce)
        self.ce.prune()
        len_after = len(self.ce)
        self.assertEqual(len_before, len_after)

        # Set all parameters to zero except three
        self.ce._parameters = np.array([0.0] * len_after)
        self.ce._parameters[0] = 1.0
        self.ce._parameters[1] = 2.0
        self.ce._parameters[2] = 0.5
        self.ce.prune()
        self.assertEqual(len(self.ce), 3)
        self.assertNotEqual(len(self.ce), len_after)

    def test_prune_cluster_expansion_tol(self):
        """Tests pruning cluster expansion with tolerance."""
        len_before = len(self.ce)
        self.ce.prune()
        len_after = len(self.ce)
        self.assertEqual(len_before, len_after)

        # Set all parameters to zero except two, one of which is
        # non-zero but below the tolerance
        self.ce._parameters = np.array([0.0] * len_after)
        self.ce._parameters[0] = 1.0
        self.ce._parameters[1] = 0.01
        self.ce.prune(tol=0.02)
        self.assertEqual(len(self.ce), 1)
        self.assertNotEqual(len(self.ce), len_after)

    def test_prune_pairs(self):
        """Tests pruning pairs only."""
        df = self.ce.to_dataframe()
        pair_indices = df.index[df['order'] == 2].tolist()
        self.ce.prune(indices=pair_indices)

        df_new = self.ce.to_dataframe()
        pair_indices_new = df_new.index[df_new['order'] == 2].tolist()
        self.assertEqual(pair_indices_new, [])

    def test_prune_zerolet(self):
        """Tests pruning zerolet."""
        with self.assertRaises(ValueError) as context:
            self.ce.prune(indices=[0])
        self.assertTrue('zerolet may not be pruned' in str(context.exception))

    def test_plot_ecis(self):
        """Tests plot_ecis."""
        self.ce.plot_ecis()

    def test_repr(self):
        """Tests repr functionality."""
        retval = self.ce.__repr__()
        target = """
================================================ Cluster Expansion =================================================
 space group                            : Fm-3m (225)
 chemical species                       : ['Au', 'Pd'] (sublattice A)
 cutoffs                                : 3.0000 3.0000 3.0000
 total number of parameters             : 5
 number of parameters by order          : 0= 1  1= 1  2= 1  3= 1  4= 1
 fractional_position_tolerance          : 2e-06
 position_tolerance                     : 1e-05
 symprec                                : 1e-05
 total number of nonzero parameters     : 4
 number of nonzero parameters by order  : 0= 0  1= 1  2= 1  3= 1  4= 1
--------------------------------------------------------------------------------------------------------------------
index | order |  radius  | multiplicity | orbit_index | multi_component_vector | sublattices | parameter |    ECI
--------------------------------------------------------------------------------------------------------------------
   0  |   0   |   0.0000 |        1     |      -1     |           .            |      .      |         0 |         0
   1  |   1   |   0.0000 |        1     |       0     |          [0]           |      A      |         1 |         1
   2  |   2   |   1.4425 |        6     |       1     |         [0, 0]         |     A-A     |         2 |     0.333
   3  |   3   |   1.6657 |        8     |       2     |       [0, 0, 0]        |    A-A-A    |         3 |     0.375
   4  |   4   |   1.7667 |        2     |       3     |      [0, 0, 0, 0]      |   A-A-A-A   |         4 |         2
====================================================================================================================
"""  # noqa

        self.assertEqual(strip_surrounding_spaces(target),
                         strip_surrounding_spaces(retval))

    def test_get_string_representation(self):
        """Tests _get_string_representation functionality."""
        retval = self.ce._get_string_representation(print_threshold=2,
                                                    print_minimum=1)
        target = """
================================================ Cluster Expansion =================================================
 space group                            : Fm-3m (225)
 chemical species                       : ['Au', 'Pd'] (sublattice A)
 cutoffs                                : 3.0000 3.0000 3.0000
 total number of parameters             : 5
 number of parameters by order          : 0= 1  1= 1  2= 1  3= 1  4= 1
 fractional_position_tolerance          : 2e-06
 position_tolerance                     : 1e-05
 symprec                                : 1e-05
 total number of nonzero parameters     : 4
 number of nonzero parameters by order  : 0= 0  1= 1  2= 1  3= 1  4= 1
--------------------------------------------------------------------------------------------------------------------
index | order |  radius  | multiplicity | orbit_index | multi_component_vector | sublattices | parameter |    ECI
--------------------------------------------------------------------------------------------------------------------
   0  |   0   |   0.0000 |        1     |      -1     |           .            |      .      |         0 |         0
 ...
   4  |   4   |   1.7667 |        2     |       3     |      [0, 0, 0, 0]      |   A-A-A-A   |         4 |         2
====================================================================================================================
"""  # noqa
        self.assertEqual(strip_surrounding_spaces(target),
                         strip_surrounding_spaces(retval))

    def test_print_overview(self):
        """Tests print_overview functionality."""
        with StringIO() as capturedOutput:
            sys.stdout = capturedOutput  # redirect stdout
            self.ce.print_overview()
            sys.stdout = sys.__stdout__  # reset redirect
            self.assertTrue('Cluster Expansion' in capturedOutput.getvalue())
class TestClusterExpansionTernary(unittest.TestCase):
    """Container for tests of the class functionality."""
    def __init__(self, *args, **kwargs):
        super(TestClusterExpansionTernary, self).__init__(*args, **kwargs)
        self.structure = bulk('Au')
        self.cutoffs = [3.0] * 3
        chemical_symbols = ['Au', 'Pd', 'Ag']
        self.cs = ClusterSpace(self.structure, self.cutoffs, chemical_symbols)

    def shortDescription(self):
        """Silences unittest from printing the docstrings in test cases."""
        return None

    def setUp(self):
        """Setup before each test."""
        params_len = len(self.cs)
        self.parameters = np.arange(params_len)
        self.ce = ClusterExpansion(self.cs, self.parameters)

    def test_prune_cluster_expansion_with_indices(self):
        """Tests pruning cluster expansion."""

        self.ce.prune(indices=[1, 2, 3, 4, 5])

    def test_prune_cluster_expansion_with_tol(self):
        """Tests pruning cluster expansion."""
        # Prune everything
        self.ce.prune(tol=1e3)
        self.assertEqual(len(self.ce), 1)

    def test_prune_pairs(self):
        """Tests pruning pairs only"""

        df = self.ce.to_dataframe()
        pair_indices = df.index[df['order'] == 2].tolist()
        self.ce.prune(indices=pair_indices)

        df_new = self.ce.to_dataframe()
        pair_indices_new = df_new.index[df_new['order'] == 2].tolist()
        self.assertEqual(pair_indices_new, [])

    def test_property_metadata(self):
        """ Test metadata property. """

        user_metadata = dict(parameters=[1, 2, 3], fit_method='ardr')
        ce = ClusterExpansion(self.cs, self.parameters, metadata=user_metadata)
        metadata = ce.metadata

        # check for user metadata
        self.assertIn('parameters', metadata.keys())
        self.assertIn('fit_method', metadata.keys())

        # check for default metadata
        self.assertIn('date_created', metadata.keys())
        self.assertIn('username', metadata.keys())
        self.assertIn('hostname', metadata.keys())
        self.assertIn('icet_version', metadata.keys())

    def test_property_primitive_structure(self):
        """ Test primitive_structure property.. """
        prim = self.cs.primitive_structure
        self.assertEqual(prim, self.ce.primitive_structure)