class StructureAnalyzer(object):
    def __init__(self, atoms):
        """Wrap ``atoms`` for structural analysis.

        Args:
            atoms: Structure object to analyze. It is only stored here;
                derived data are built later (e.g. ``update_attributes``).
        """
        self._dictionary = {}
        self._filename = None
        self._atoms = atoms

    def update_attributes(self):
        """Refresh derived attributes after ``self._atoms`` has changed.

        Regenerates the property dictionary and (re)binds ``deform_cell`` as
        an alias of ``deform_cell_right``. Call this after the atoms
        attribute has been updated.
        """
        self.generate_dictionary()
        # By default ``deform_cell`` right-multiplies the cell.
        alias = self.deform_cell_right
        self.deform_cell = alias

    def generate_dictionary(self):
        """Build the property dictionary for the current structure.

        Starts from the dictionary returned by ``convert_cell_to_lc`` (which
        must contain ``"volume"``). It adds the filename, number of atoms,
        chemical symbols and volume per atom, then adds space-group data via
        ``create_symmetry_dataset``.

        Returns:
            self, to allow chaining.
        """
        atoms = self._atoms
        cell = atoms.get_cell()
        natoms = atoms.get_number_of_atoms()
        symbols = atoms.get_chemical_symbols()

        properties = convert_cell_to_lc(cell)
        volume_per_atom = properties["volume"] / natoms

        properties["filename"] = self._filename
        properties["number_of_atoms"] = natoms
        properties["chemical_symbols"] = symbols
        properties["volume_per_atom"] = volume_per_atom

        self._dictionary = properties
        self.create_symmetry_dataset()
        return self

    def create_symmetry_dataset(self):
        """Record the space group in the property dictionary.

        Reads ``"number"`` and ``"international"`` from
        ``get_symmetry_dataset()`` and stores them under ``"spg_number"``
        and ``"spg_international"``.
        """
        dataset = self.get_symmetry_dataset()
        # Read both entries before touching the dictionary.
        number = dataset["number"]
        symbol = dataset["international"]
        self._dictionary["spg_number"] = number
        self._dictionary["spg_international"] = symbol

    def calculate_radial_distribution_functions(self,
                                                sigma=0.1,
                                                xmin=0.0,
                                                xmax=3.0,
                                                xpitch=0.01,
                                                dim=2):
        """Calculate a smeared radial distribution function for every atom.

        Note that this replaces ``self._atoms`` by the supercell built with
        ``generate_supercell(dim=dim)``. It then recomputes the distance
        matrix on that supercell. Only the nearest-image distance of each
        atom pair is used.

        Args:
            sigma, xmin, xmax, xpitch: Passed to ``Smearing``.
            dim: Supercell specification passed to ``generate_supercell``.

        Returns:
            (xs, rdfs): The grid from ``Smearing.get_xs()`` and an array with
            one smeared RDF per supercell atom. Each peak is weighted by
            1 / (4 pi r^2) and the result is divided by the number density.
        """
        from ph_unfolder.analysis.smearing import Smearing

        self.generate_supercell(dim=dim)

        print("Warning: this method is under development!")
        # TODO(ikeda): Consider atoms over boundaries
        atoms = self._atoms
        cell = atoms.get_cell()
        natoms = atoms.get_number_of_atoms()
        # abs(): det(cell) is negative for left-handed lattice vectors, which
        # would otherwise flip the sign of every RDF.
        density = natoms / abs(np.linalg.det(cell))
        print("Calculating distance matrix:", end="")
        self.generate_distance_matrix()
        print(" Finished.")
        distance_matrix = self.get_distance_matrix()
        smearing = Smearing(sigma=sigma, xmin=xmin, xmax=xmax, xpitch=xpitch)
        xs = smearing.get_xs()
        rdfs = []
        for distances in distance_matrix:
            distances = np.sort(distances)[1:]  # To avoid the same atom
            weights = 1.0 / (4.0 * np.pi * distances**2)
            rdfs.append(smearing.run(peaks=distances, weights=weights))
        rdfs = np.array(rdfs)
        rdfs /= density
        return xs, rdfs

    def calculate_distances_from_position(self, position_checked):
        """Calculate the shortest periodic distance from a point to each atom.

        For every atom, the scaled difference (atom minus point) is wrapped
        with ``np.rint``. The 27 neighboring images (shifts of -1, 0, +1
        along each lattice vector) are then compared, and the shortest
        Cartesian length is kept. On ties, the first image in
        ``itertools.product`` order wins.

        Args:
            position_checked: Scaled (fractional) position from which the
                distances are measured.

        Returns:
            (distances, scaled_distances): An array of shape (natoms,) with
            the shortest Cartesian distances, and an array of shape
            (natoms, 3) with the corresponding scaled difference vectors.
        """
        cell = np.asarray(self._atoms.get_cell(), dtype=float)
        scaled_positions = np.asarray(
            self._atoms.get_scaled_positions(), dtype=float).reshape(-1, 3)
        position_checked = np.asarray(position_checked, dtype=float)

        # Image shifts do not depend on the atom: build them once, (27, 3).
        additions = np.array(list(itertools.product(range(-1, 2), repeat=3)))

        wrapped = scaled_positions - position_checked
        wrapped -= np.rint(wrapped)
        # Candidate scaled difference vectors, shape (natoms, 27, 3).
        candidates = wrapped[:, None, :] + additions[None, :, :]
        lengths = np.linalg.norm(np.dot(candidates, cell), axis=-1)
        # argmin returns the first minimum, i.e. the same tie-break as a
        # strict ``>`` comparison over the images in product order.
        nearest = np.argmin(lengths, axis=1)
        rows = np.arange(len(scaled_positions))
        return lengths[rows, nearest], candidates[rows, nearest]

    def generate_distance_matrix(self):
        """Store nearest-image distances and scaled difference vectors.

        ``_create_distance_matrix_expanded`` returns the periodic images
        sorted by distance. Only the first (closest) image is kept, in
        ``_distance_matrix`` and ``_scaled_distances``.

        Returns:
            self, to allow chaining.
        """
        sorted_distances, sorted_differences = (
            self._create_distance_matrix_expanded())
        self._distance_matrix = sorted_distances[..., 0]
        self._scaled_distances = sorted_differences[..., 0, :]
        return self

    def _create_distance_matrix_expanded(self):
        """Distances between all atom pairs over the 27 neighboring images.

        The scaled differences ``p_j - p_i`` are wrapped with ``np.rint``.
        Each is then shifted by every vector in {-1, 0, 1}^3. For each pair,
        the images are sorted by Cartesian distance. The sort is ascending
        and stable, so ties keep ``itertools.product`` order.

        Returns:
            (dm, diffs): ``dm`` has shape (natoms, natoms, 27) and holds the
            sorted distances. ``diffs`` has shape (natoms, natoms, 27, 3) and
            holds the matching scaled difference vectors, so the length of
            ``diffs[i, j, k]`` is ``dm[i, j, k]``.
        """
        cell = self._atoms.get_cell()
        scaled_positions = self._atoms.get_scaled_positions()

        additions = np.array(list(itertools.product(range(-1, 2), repeat=3)))

        diffs = scaled_positions[None, :, :] - scaled_positions[:, None, :]
        diffs -= np.rint(diffs)
        diffs = diffs[:, :, None, :] + additions

        dm = np.linalg.norm(np.dot(diffs, cell), axis=-1)
        indices = np.argsort(dm, axis=-1, kind="mergesort")

        # Index grids for the first two axes, shaped like ``indices``. Grids
        # shaped like ``diffs`` would broadcast against the trailing xyz
        # axis and add a spurious fifth dimension. With these grids the xyz
        # axis stays a plain slice.
        rows, cols = np.indices(indices.shape)[:2]
        dm = dm[rows, cols, indices]
        diffs = diffs[rows, cols, indices]

        return dm, diffs

    def write_properties(self, precision=16):
        """Print the known structural properties, one ``key: value`` per line.

        A header with the filename is printed first. Keys are printed in a
        fixed order and skipped when absent from the property dictionary.
        Floats use ``precision`` decimals in a field of width
        ``precision + 6``. Integers (including NumPy integer scalars) use a
        field of width 5. Anything else, such as strings or ``None`` for an
        unset filename, is printed with ``str()``.

        Args:
            precision: Number of decimal places for float values.
        """
        import numbers

        width = precision + 6
        width_int = 5

        key_order = [
            "filename",
            "number_of_atoms",
            "volume",
            "volume_per_atom",
            "a",
            "b",
            "c",
            "b/a",
            "c/a",
            "a/b",
            "c/b",
            "a/c",
            "b/c",
            "alpha",
            "beta",
            "gamma",
            "b_x_c",
            "c_x_a",
            "a_x_b",
            "spg_number",
            "spg_international",
        ]

        print("-" * 80)
        print(self._filename)
        print("-" * 80)
        for k in key_order:
            if k not in self._dictionary:
                continue
            value = self._dictionary[k]
            sys.stdout.write("{:s}: ".format(k))
            if isinstance(value, float):
                text = "{:{width}.{precision}f}".format(
                    value, width=width, precision=precision)
            elif isinstance(value, numbers.Integral):
                # numbers.Integral also covers NumPy integer scalars, which
                # are not ``int`` subclasses on Python 3.
                text = "{:{width}d}".format(int(value), width=width_int)
            elif isinstance(value, numbers.Real):
                # E.g. np.float32, which is not a ``float`` subclass.
                text = "{:{width}.{precision}f}".format(
                    float(value), width=width, precision=precision)
            else:
                # "{:s}" would raise TypeError for non-str values like None.
                text = "{}".format(value)
            sys.stdout.write(text)
            sys.stdout.write("\n")

    def write_specified_properties(self, keys, precision=16):
        """Print the given properties on a single line as `` key value`` pairs.

        Floats use ``precision`` decimals in a field of width
        ``precision + 6``. Integers (including NumPy integer scalars) use a
        field of width 5, padded with ``precision + 1`` spaces to line up
        with floats. Anything else, such as strings or ``None``, is printed
        with ``str()``.

        Args:
            keys: Keys of the property dictionary to print, in order.
            precision: Number of decimal places for float values.

        Raises:
            KeyError: If a key is not in the property dictionary.
        """
        import numbers

        width = precision + 6
        width_int = 5
        for k in keys:
            value = self._dictionary[k]
            sys.stdout.write(" {:s} ".format(k))
            if isinstance(value, float):
                sys.stdout.write("{:{width}.{precision}f}".format(
                    value, width=width, precision=precision))
            elif isinstance(value, numbers.Integral):
                # numbers.Integral also covers NumPy integer scalars.
                sys.stdout.write("{:{width}d}".format(int(value),
                                                      width=width_int))
                sys.stdout.write(" " * (precision + 1))
            elif isinstance(value, numbers.Real):
                sys.stdout.write("{:{width}.{precision}f}".format(
                    float(value), width=width, precision=precision))
            else:
                # "{:s}" would raise TypeError for non-str values like None.
                sys.stdout.write("{}".format(value))
        sys.stdout.write("\n")

    def get_index_from_position(self, position, symprec=1e-6):
        """Return the index of the atom at a scaled position, modulo lattice.

        The difference to each atom is wrapped with ``np.rint``. The first
        atom whose wrapped difference is below ``symprec`` in every
        component is returned. If none matches, a warning is printed and
        None is returned.
        """
        scaled_positions = self._atoms.get_scaled_positions()
        for index, candidate in enumerate(scaled_positions):
            delta = position - candidate
            delta = delta - np.rint(delta)
            if np.all(np.abs(delta) < symprec):
                return index
        print("WARNING: {}".format(__name__))
        print("Index for the specified position cannot be found.")
        return None

    def write_distance_matrix(self):
        """Print ``_distance_matrix``: one row per atom, each entry 22.16f."""
        natoms = self._atoms.get_number_of_atoms()
        for row in range(natoms):
            line = "".join(
                "{:22.16f}".format(self._distance_matrix[row, col])
                for col in range(natoms))
            sys.stdout.write(line + "\n")

    def write_sorted_distance_matrix(self):
        """Print, for each atom, all atoms ordered by their distance to it.

        Each line holds the index and symbol of the reference atom, the
        index and symbol of the other atom, their distance (from
        ``_distance_matrix``) and the scaled difference vector (from
        ``_scaled_distances``). A blank line follows each reference atom.
        ``generate_distance_matrix()`` must have been called first.
        """
        chemical_symbols = self._atoms.get_chemical_symbols()
        for i1, c1 in enumerate(chemical_symbols):
            distances = self._distance_matrix[i1]
            for i2 in np.argsort(distances):
                c2 = chemical_symbols[i2]
                dp = self._scaled_distances[i1, i2]
                sys.stdout.write("{:6d}{:>6s}{:6d}{:>6s}{:12.6f} ".format(
                    i1, c1, int(i2), c2, distances[i2]))
                sys.stdout.write(("{:12.6f}" * 3).format(*dp))
                sys.stdout.write("\n")
            sys.stdout.write("\n")

    def set_atoms(self, atoms):
        """Replace the wrapped structure; derived attributes are not refreshed.

        Returns:
            self, to allow chaining.
        """
        self._atoms = atoms
        return self

    def set_scaled_positions(self, scaled_positions):
        """Forward scaled positions to the wrapped atoms; returns self."""
        atoms = self._atoms
        atoms.set_scaled_positions(scaled_positions)
        return self

    def set_positions(self, positions):
        """Set Cartesian positions by converting them to scaled coordinates.

        Scaled positions are ``positions . inv(cell)``.

        Returns:
            self, to allow chaining.
        """
        inverse_cell = np.linalg.inv(self._atoms.get_cell())
        return self.set_scaled_positions(np.dot(positions, inverse_cell))

    def displace_scaled_positions(self, scaled_displacements):
        """Add scaled displacements to the current scaled positions.

        Returns:
            self, to allow chaining.
        """
        positions = self._atoms.get_scaled_positions()
        positions += scaled_displacements
        return self.set_scaled_positions(positions)

    def displace_positions(self, displacements):
        """Displace atoms by Cartesian vectors (converted via ``inv(cell)``).

        Returns:
            self, to allow chaining.
        """
        inverse_cell = np.linalg.inv(self._atoms.get_cell())
        scaled = np.dot(displacements, inverse_cell)
        return self.displace_scaled_positions(scaled)

    def shift_to_origin(self, index):
        """Translate all atoms so that atom ``index`` sits at the origin.

        Returns:
            self, to allow chaining.
        """
        origin = self._atoms.get_scaled_positions()[index]
        return self.displace_scaled_positions(-origin)

    def remove_atoms_indices(self, indices):
        """Remove the atoms with the given indices.

        Args:
            indices: A single index (Python or NumPy integer) or an iterable
                of indices. Negative indices count from the end. An atom
                referenced more than once (e.g. ``-1`` and ``natoms - 1``)
                is removed only once.

        Returns:
            self, to allow chaining.

        Raises:
            IndexError: If an index is out of range.
            TypeError: If an index is not an integer.
        """
        import operator

        scaled_positions = np.asarray(self._atoms.get_scaled_positions())
        chemical_symbols = self._atoms.get_chemical_symbols()
        natoms = len(scaled_positions)

        # np.ndim() == 0 also accepts NumPy integer scalars, which are not
        # ``int`` instances on Python 3.
        if np.ndim(indices) == 0:
            indices = [indices]

        removed = set()
        for index in indices:
            index = operator.index(index)
            if not -natoms <= index < natoms:
                raise IndexError(
                    "atom index {} is out of range for {} atoms".format(
                        index, natoms))
            # Normalize so that negative and positive aliases coincide.
            removed.add(index % natoms)

        kept = [i for i in range(natoms) if i not in removed]
        self.set_scaled_positions(scaled_positions[kept])
        self.set_chemical_symbols([chemical_symbols[i] for i in kept])
        self._atoms._symbols_to_numbers()
        self._atoms._symbols_to_masses()
        return self

    def remove_atoms_outside(self, region):
        """Remove atoms lying outside a box given in scaled coordinates.

        Atoms are wrapped first (``wrap_into_cell``). An atom is removed if
        any of its scaled coordinates is below the lower bound or above the
        upper bound for that axis. Atoms exactly on a bound are kept.

        Args:
            region: 3 x 2 array-like given by direct coordinates,
                [[a-, a+],
                 [b-, b+],
                 [c-, c+]]

        Returns:
            self, to allow chaining.
        """
        # NOTE(review): ``wrap_into_cell`` is not defined in the visible part
        # of this class (only ``wrap`` is) -- confirm it exists elsewhere.
        self.wrap_into_cell()
        bounds = np.array(region)
        scaled_positions = np.reshape(self._atoms.get_scaled_positions(),
                                      (-1, 3))
        too_low = scaled_positions < bounds[:, 0]
        too_high = bounds[:, 1] < scaled_positions
        outside = np.any(too_low | too_high, axis=1)
        return self.remove_atoms_indices(np.nonzero(outside)[0].tolist())

    def add_vacuum_layer(self, vacuum_layer):
        """Enlarge the cell with vacuum on each side along each lattice vector.

        Along axis ``i`` the lattice vector is scaled by
        ``1 + vacuum_layer[i, 0] + vacuum_layer[i, 1]``. Scaled coordinates
        are shifted by ``vacuum_layer[i, 0]`` and divided by the same
        factor, so the atoms keep their Cartesian arrangement.

        Args:
            vacuum_layer: 3 x 2 array-like given by direct coordinates,
                [[a-, a+],
                 [b-, b+],
                 [c-, c+]]

        Returns:
            self, to allow chaining.
        """
        # NOTE(review): ``wrap_into_cell`` is not defined in the visible part
        # of this class (only ``wrap`` is) -- confirm it exists elsewhere.
        self.wrap_into_cell()
        layers = np.array(vacuum_layer)
        cell = self._atoms.get_cell()
        scaled_positions = self._atoms.get_scaled_positions()

        stretches = 1.0 + layers.sum(axis=1)  # one factor per lattice vector
        cell *= stretches[:, None]  # row i is lattice vector i
        scaled_positions += layers[:, 0]
        scaled_positions /= stretches

        self.set_cell(cell)
        self.set_scaled_positions(scaled_positions)
        return self

    def wrap(self, center=(0.5, 0.5, 0.5)):
        """Wrap scaled positions into a unit-length box around ``center``.

        Along each axis, coordinates are mapped by whole lattice vectors
        into the interval starting at ``center - 0.5`` (via ``np.floor``).

        Returns:
            self, to allow chaining.
        """
        shifted = self._atoms.get_scaled_positions()
        offset = np.array(center) - (0.5, 0.5, 0.5)
        shifted -= offset
        shifted -= np.floor(shifted)
        return self.set_scaled_positions(shifted + offset)

    def set_cell(self, cell):
        """Forward a new cell to the wrapped atoms; returns self."""
        atoms = self._atoms
        atoms.set_cell(cell)
        return self

    def set_chemical_symbols(self, symbols):
        """Forward chemical symbols to the wrapped atoms; returns self."""
        atoms = self._atoms
        atoms.set_chemical_symbols(symbols)
        return self

    def deform_cell_left(self, matrix):
        """Deform cell as (a, b, c) = M * (a, b, c).

        The lattice vectors (rows of the cell) become the columns of
        ``cell.T`` and are left-multiplied by ``M``. Derived attributes are
        then refreshed with ``update_attributes``.

        Returns:
            self, to allow chaining.
        """
        deformation = _get_matrix(matrix)
        lattice_columns = self._atoms.get_cell().T
        deformed = np.dot(deformation, lattice_columns).T
        self._atoms.set_cell(deformed)
        self.update_attributes()
        return self

    def deform_cell_right(self, matrix):
        """Deform cell as (a, b, c) = (a, b, c) * M.

        The lattice vectors (rows of the cell) become the columns of
        ``cell.T`` and are right-multiplied by ``M``. Derived attributes
        are then refreshed with ``update_attributes``.

        Returns:
            self, to allow chaining.
        """
        deformation = _get_matrix(matrix)
        lattice_columns = self._atoms.get_cell().T
        deformed = np.dot(lattice_columns, deformation).T
        self._atoms.set_cell(deformed)
        self.update_attributes()
        return self

    def generate_supercell(self, dim, prec=1e-9):
        """Replace the structure by the supercell defined by ``dim``.

        (a_s, b_s, c_s) = (a_u, b_u, c_u) * dim

        Each chemical symbol is repeated |det(dim)| times (rounded to the
        nearest integer). This matches the atom-major layout returned by
        ``_generate_supercell_positions``. A new ``Atoms`` object replaces
        ``self._atoms``, and derived attributes are refreshed.

        Args:
            dim: Supercell matrix specification, converted by ``_get_matrix``.
            prec: Tolerance passed to ``_generate_supercell_positions``.

        Returns:
            self, to allow chaining.
        """
        dim = _get_matrix(dim)
        # Builtin int(): the ``np.int`` alias was deprecated in NumPy 1.20
        # and removed in NumPy 1.24.
        nexpansion = int(np.rint(np.abs(np.linalg.det(dim))))

        # Generate lattice vectors for the supercell.
        cell = self._atoms.get_cell()
        cell = np.dot(cell.T, dim).T

        chemical_symbols_new = []
        for chemical_symbol in self._atoms.get_chemical_symbols():
            chemical_symbols_new += [chemical_symbol] * nexpansion

        supercell_positions = self._generate_supercell_positions(dim, prec)

        self._atoms = Atoms(cell=cell,
                            symbols=chemical_symbols_new,
                            scaled_positions=supercell_positions)
        self.update_attributes()

        return self

    def _generate_supercell_positions(self, dim, prec=1e-9):
        """Generate scaled positions of all atoms in the supercell.

        Unit-cell scaled positions are converted to the supercell basis
        (``inv(dim) . f``). Every translation from ``find_lattice_vectors``
        is then added to them. Assuming the translations come as an
        (ntrans, 3) array (TODO confirm), the result has shape
        (natoms * ntrans, 3), with all images of one atom consecutive.
        """
        translations = find_lattice_vectors(dim, prec=prec)

        unit_positions = self._atoms.get_scaled_positions()
        # Convert positions into the fractional coordinates of the supercell.
        positions = np.dot(np.linalg.inv(dim), unit_positions.T).T

        combined = positions[:, None] + translations[None, :]
        return combined.reshape(-1, 3)

    def sort_by_coordinates(self, index, sorted_by_symbols=False):
        """Sort atoms by one scaled coordinate (stable).

        Args:
            index: Coordinate used as the sort key.
                0: a, 1: b, 2: c
            sorted_by_symbols: If True, then group equal symbols with
                ``sort_by_symbols``. The grouping uses the first-occurrence
                order of the symbols from before this sort.

        Returns:
            self, to allow chaining.
        """
        symbols = self._atoms.get_chemical_symbols()
        positions = np.asarray(self._atoms.get_scaled_positions())
        order = list(symbols)
        # Sort a permutation rather than subscripting ``zip(*data)``. A zip
        # object is not subscriptable on Python 3, and this form also
        # handles zero atoms. ``sorted`` is stable, so atoms with equal
        # coordinates keep their relative order.
        permutation = sorted(range(len(symbols)),
                             key=lambda i: positions[i][index])
        self.set_chemical_symbols([symbols[i] for i in permutation])
        self.set_scaled_positions(positions[permutation])
        if sorted_by_symbols:
            self.sort_by_symbols(order=order)
        self._atoms._symbols_to_numbers()
        self._atoms._symbols_to_masses()
        return self

    def sort_by_symbols(self, order=None, atomic_properties=None):
        """Combine the same chemical symbols.

        Atoms are sorted (stably) by the first position at which their
        symbol appears in ``order``. Positions and optional per-atom
        properties are permuted the same way.

        Args:
            order: Sequence of symbols defining the order. Defaults to the
                current symbols, i.e. first-occurrence order.
            atomic_properties: Optional dict mapping a name to a per-atom
                sequence with one entry per atom.

        Returns:
            None if ``atomic_properties`` is None. Otherwise, a dict with
            the same keys whose values are tuples reordered like the atoms.

        Raises:
            ValueError: If a symbol is missing from ``order``, or if a
                property does not have one entry per atom.
        """
        symbols = self._atoms.get_chemical_symbols()
        positions = np.asarray(self._atoms.get_scaled_positions())
        natoms = len(symbols)
        if order is None:
            order = list(symbols)

        # First-occurrence rank of each symbol: same result as
        # ``order.index`` but O(1) per lookup instead of O(len(order)).
        rank = {}
        for i, symbol in enumerate(order):
            rank.setdefault(symbol, i)
        for symbol in symbols:
            if symbol not in rank:
                raise ValueError("{!r} is not in order".format(symbol))

        permutation = sorted(range(natoms), key=lambda i: rank[symbols[i]])

        # Validate and reorder properties before mutating the atoms, so a
        # bad property leaves the structure untouched.
        atomic_properties_sorted = None
        if atomic_properties is not None:
            atomic_properties_sorted = {}
            for key, values in atomic_properties.items():
                values = list(values)
                if len(values) != natoms:
                    raise ValueError(
                        "atomic property {!r} has {} entries for {} atoms"
                        .format(key, len(values), natoms))
                atomic_properties_sorted[key] = tuple(
                    values[i] for i in permutation)

        # Sorting a permutation avoids subscripting ``zip(*data)``, which
        # fails on Python 3.
        self.set_chemical_symbols([symbols[i] for i in permutation])
        self.set_scaled_positions(positions[permutation])
        self._atoms._symbols_to_numbers()
        self._atoms._symbols_to_masses()

        return atomic_properties_sorted

    def get_dictionary(self):
        """Return the property dictionary itself (not a copy)."""
        dictionary = self._dictionary
        return dictionary

    def get_atoms(self):
        """Return the wrapped structure object itself (not a copy)."""
        atoms = self._atoms
        return atoms

    def get_cell(self):
        """Return whatever the wrapped atoms' ``get_cell()`` returns."""
        atoms = self._atoms
        return atoms.get_cell()

    def get_scaled_distances(self):
        """Return a copy of the scaled nearest-image difference vectors.

        Set by ``generate_distance_matrix``.
        """
        stored = self._scaled_distances
        return stored.copy()

    def get_distance_matrix(self):
        """Return a copy of the nearest-image distance matrix.

        Set by ``generate_distance_matrix``.
        """
        stored = self._distance_matrix
        return stored.copy()

    def change_volume(self, volume):
        """Scale the cell isotropically so that its volume becomes ``volume``.

        The lattice is multiplied by ``(volume / |det(cell)|) ** (1/3)``.
        Derived attributes (the property dictionary) are not refreshed here.

        Args:
            volume: Target (positive) cell volume.

        Returns:
            self, to allow chaining.
        """
        cell_current = self._atoms.get_cell()
        # abs(): det is negative for a left-handed cell. A negative base
        # raised to the power 1/3 would give NaN.
        volume_current = abs(np.linalg.det(cell_current))
        scale = (volume / volume_current)**(1.0 / 3.0)
        self._atoms.set_cell(cell_current * scale)
        return self

    def change_volume_per_atom(self, volume_per_atom):
        """Rescale the cell to ``volume_per_atom`` times the number of atoms.

        Delegates to ``change_volume``.

        Returns:
            self, to allow chaining.
        """
        natoms = self._atoms.get_number_of_atoms()
        target_volume = volume_per_atom * natoms
        return self.change_volume(target_volume)

    def get_symmetry_dataset(self, *args, **kwargs):
        """Return ``Symmetry(atoms, *args, **kwargs).get_dataset()``.

        Extra arguments are forwarded unchanged to ``Symmetry``.
        """
        symmetry = Symmetry(self._atoms, *args, **kwargs)
        return symmetry.get_dataset()

    def get_mappings_for_symops(self, prec=1e-6):
        """Get atom mappings for all symmetry operations of the structure.

        Args:
            prec: Tolerance passed to ``extract_mapping_for_symopr``.

        Returns:
            Integer array of shape (nopr, natoms). Row ``iopr`` is the
            mapping from ``extract_mapping_for_symopr`` for the ``iopr``-th
            rotation/translation pair of ``get_symmetry_dataset()``.

        Raises:
            ValueError: If some atom is not mapped by some operation. The
                message reports how many pairs failed and the first one.
        """
        natoms = self._atoms.get_number_of_atoms()

        dataset = self.get_symmetry_dataset()
        rotations = dataset["rotations"]
        translations = dataset["translations"]
        nopr = len(rotations)
        mappings = -1 * np.ones((nopr, natoms), dtype=int)
        for iopr, (r, t) in enumerate(zip(rotations, translations)):
            mappings[iopr] = self.extract_mapping_for_symopr(r, t, prec)[0]

        unmapped = np.argwhere(mappings == -1)
        if len(unmapped) > 0:
            print("ERROR: {}".format(__name__))
            print("Some atoms are not mapped by some symmetry operations.")
            raise ValueError(
                "{} (operation, atom) pairs are not mapped; "
                "first: operation {}, atom {}".format(
                    len(unmapped), unmapped[0][0], unmapped[0][1]))

        return mappings

    def extract_transformed_scaled_positions(self, rotation, translation):
        """Apply a rotation and translation to the current scaled positions.

        Args:
            rotation (3x3 array): Rotation matrix.
            translation (3 array): Translation vector.

        Returns:
            The result of ``transform_scaled_positions`` on the current
            scaled positions. If the pair is not a symmetry operation of the
            structure, the positions need not coincide with any atoms.
        """
        current = self._atoms.get_scaled_positions()
        return transform_scaled_positions(current, rotation, translation)

    def extract_mapping_for_symopr(self, rotation, translation, prec=1e-6):
        """Extract the atom mapping induced by one symmetry operation.

        The current structure is transformed by ``(rotation, translation)``
        and matched atom by atom against itself with
        ``extract_mapping_for_atoms``.

        Args:
            rotation (3x3 array): Rotation matrix.
            translation (3 array): Translation vector.
            prec: Matching tolerance on wrapped scaled coordinates.

        Returns:
            (mapping, diff_positions), as returned by
            ``extract_mapping_for_atoms``. Indices of ``mapping`` are for
            the new (transformed) atoms and its contents are for the old
            ones.
        """
        symbols = self._atoms.get_chemical_symbols()
        transformed = self.extract_transformed_scaled_positions(
            rotation, translation)
        mapping, lattice_shifts = self.extract_mapping_for_atoms(
            symbols, transformed, prec)
        return mapping, lattice_shifts

    def extract_mapping_for_atoms(self, symbols_new, positions_new, prec=1e-6):
        """Match transformed atoms to the current atoms.

        For each new atom ``i``, the first current atom ``j`` is found that
        has the same symbol and whose scaled position differs by a lattice
        vector within ``prec`` in every component.

        Args:
            symbols_new: Chemical symbols for the transformed structures.
            positions_new: Fractional positions for the transformed structures.
            prec: Tolerance on the wrapped scaled-position difference.

        Returns:
            (mapping, diff_positions):
                mapping (n integral array): Indices are for new numbers and
                    contents are for old ones. ``mapping[i] == j`` means the
                    i-th transformed atom coincides (modulo a lattice
                    vector) with the j-th current atom. -1 means no match.
                diff_positions (n x 3 integral array): The lattice vector
                    ``rint(new - old)`` for each matched atom; 0 otherwise.
        """
        natoms = self._atoms.get_number_of_atoms()
        old_symbols = self._atoms.get_chemical_symbols()
        old_positions = self._atoms.get_scaled_positions()

        mapping = np.full(natoms, -1, dtype=int)
        lattice_shifts = np.zeros((natoms, 3), dtype=int)
        for i_new, new_position in enumerate(positions_new):
            wanted_symbol = symbols_new[i_new]
            for j_old, old_position in enumerate(old_positions):
                if old_symbols[j_old] != wanted_symbol:
                    continue
                delta = new_position - old_position
                nearest = np.rint(delta)
                if np.all(np.abs(delta - nearest) < prec):
                    mapping[i_new] = j_old
                    lattice_shifts[i_new] = nearest.astype(int)
                    break

        return mapping, lattice_shifts