示例#1
0
    def __init__(self, structure, kpoints, coefficients):
        """Prepare per-spin interpolators of wavefunction coefficients.

        The k-points are folded with ``kpoints_to_first_bz``, sorted and laid
        out on a regular 3D mesh; the coefficients of every spin channel are
        rearranged the same way and wrapped in an interpolator stored in
        ``self.interpolators[spin]``.

        Args:
            structure: Stored unchanged as ``self.structure``.
            kpoints: The k-points, one row of three coordinates per point.
                They must cover the full Brillouin zone so that they can be
                reshaped onto the mesh.
            coefficients: Mapping from spin to an array whose first axis is
                the band, second axis the k-point (same order as ``kpoints``)
                and last axis the coefficient index.
        """
        logger.info("Initializing wavefunction overlap calculator")
        self.structure = structure

        # the k-points must span the whole BZ; the mesh dimensions are taken
        # from the folded points before any rounding is applied
        folded_kpoints = kpoints_to_first_bz(kpoints)
        mesh_dim = get_mesh_dim_from_kpoints(folded_kpoints, tol=1e-4)

        # round to the number of decimals matching a 1e-6 tolerance so that
        # numerically equal coordinates sort together
        n_decimals = int(np.log10(1 / 1e-6))
        folded_kpoints = np.round(folded_kpoints, n_decimals)

        # np.lexsort uses the last key as the primary one, so this orders the
        # points by X first, then Y, then Z
        order = np.lexsort(
            (folded_kpoints[:, 2], folded_kpoints[:, 1], folded_kpoints[:, 0])
        )

        # mesh[ikx][iky][ikz] = [kx, ky, kz]
        mesh = folded_kpoints[order].reshape(mesh_dim + (3, ))

        # coordinates along each mesh axis
        x = mesh[:, 0, 0, 0]
        y = mesh[0, :, 0, 1]
        z = mesh[0, 0, :, 2]

        self.nbands = {
            spin: values.shape[0] for spin, values in coefficients.items()
        }

        # TODO: Expand the k-point mesh to account for periodic boundary conditions
        self.interpolators = {}
        for spin, values in coefficients.items():
            band_count = values.shape[0]
            coeff_count = values.shape[-1]

            # apply the k-point ordering and fold the k-point axis into the
            # mesh, giving gridded[iband][ikx][iky][ikz]
            target_shape = (band_count, ) + mesh_dim + (coeff_count, )
            gridded = values[:, order].reshape(target_shape)

            if band_count == 1:
                # a single band can cause a bug in RegularGridInterpolator, so
                # duplicate it to fake having two bands
                band_count = 2
                gridded = np.tile(gridded, (2, 1, 1, 1, 1))

            if not eval_linear:
                self.interpolators[spin] = RegularGridInterpolator(
                    (np.arange(band_count), x, y, z),
                    gridded,
                    bounds_error=False,
                    fill_value=None)
                continue

            # eval_linear is available: keep the uniform grid description
            # together with the gridded data
            uniform_grid = UCGrid(
                (0, band_count - 1, band_count),
                (x[0], x[-1], len(x)),
                (y[0], y[-1], len(y)),
                (z[0], z[-1], len(z)),
            )
            self.interpolators[spin] = (uniform_grid, gridded)
示例#2
0
文件: overlap.py 项目: kcbhamu/amset
    def __init__(
        self,
        structure,
        kpoints,
        projections,
        band_centers,
        symprec=defaults["symprec"],
        kpoint_symmetry_mapping=None,
    ):
        """Set up per-spin interpolators of normalised orbital projections.

        The projections are expanded from the supplied k-points to the full
        k-point set, flattened and normalised per band and k-point, arranged
        on a regular 3D mesh and wrapped in a ``RegularGridInterpolator``
        stored in ``self.interpolators[spin]``.

        Args:
            structure: Passed to ``expand_kpoints`` when no symmetry mapping
                is supplied.
            kpoints: The k-points the projections are given on; only used
                when ``kpoint_symmetry_mapping`` is not supplied.
            projections: Mapping from spin to an array whose first axis is
                the band and second axis the k-point; all remaining axes are
                flattened (Fortran order) into one projection axis.
            band_centers: Stored unchanged as ``self.band_centers``.
            symprec: Symmetry precision forwarded to ``expand_kpoints``.
            kpoint_symmetry_mapping: Optional precomputed
                ``(full_kpoints, ir_to_full_idx, rot_mapping)`` triple. If
                falsy, it is computed with ``expand_kpoints``. The rotation
                mapping is not used by this constructor.
        """
        logger.info("Initializing orbital overlap calculator")

        # k-points have to be on a regular grid, even if only the irreducible part of
        # the grid is used. If the irreducible part is given, we have to expand it
        # to the full BZ, and the projections are expanded alongside using the
        # irreducible-to-full index mapping
        if kpoint_symmetry_mapping:
            full_kpoints, ir_to_full_idx, _ = kpoint_symmetry_mapping
        else:
            full_kpoints, ir_to_full_idx, _ = expand_kpoints(
                structure, kpoints, symprec=symprec)

        mesh_dim = get_mesh_dim_from_kpoints(full_kpoints)

        round_dp = int(np.log10(1 / 1e-6))
        full_kpoints = np.round(full_kpoints, round_dp)

        # np.lexsort uses the last key as the primary one, so this orders the
        # k-points by X first, then Y, then Z
        sort_idx = np.lexsort(
            (full_kpoints[:, 2], full_kpoints[:, 1], full_kpoints[:, 0]))

        # put the kpoints into a 3D grid so that they can be indexed as
        # kpoints[ikx][iky][ikz] = [kx, ky, kz]
        grid_kpoints = full_kpoints[sort_idx].reshape(mesh_dim + (3, ))

        x = grid_kpoints[:, 0, 0, 0]
        y = grid_kpoints[0, :, 0, 1]
        z = grid_kpoints[0, 0, :, 2]

        self.nbands = {s: p.shape[0] for s, p in projections.items()}

        # the number of k-points in the full set does not depend on the spin
        nkpoints = len(ir_to_full_idx)

        # TODO: Expand the k-point mesh to account for periodic boundary conditions
        self.interpolators = {}
        for spin, spin_projections in projections.items():
            nbands = spin_projections.shape[0]
            # np.product was removed in NumPy 2.0; np.prod is the supported
            # spelling. The cast keeps the value usable as a reshape dimension.
            nprojections = int(np.prod(spin_projections.shape[2:]))

            expand_projections = spin_projections[:, ir_to_full_idx]
            flat_projections = expand_projections.reshape(
                (nbands, nkpoints, -1), order="F")

            # aim is to get the wavefunction coefficients. A projection vector
            # that is entirely zero gives 0/0 = NaN, which is intentionally
            # mapped to zero below, so silence the expected "invalid value"
            # floating point error for this step only
            with np.errstate(invalid="ignore"):
                norms = np.sqrt((flat_projections**2).sum(axis=2))
                coefficients = flat_projections / norms[..., None]
            coefficients[np.isnan(coefficients)] = 0

            # sort the coefficients then reshape them into the grid. The coefficients
            # can now be indexed as coefficients[iband][ikx][iky][ikz]
            sorted_coefficients = coefficients[:, sort_idx]
            grid_shape = (nbands, ) + mesh_dim + (nprojections, )
            grid_coefficients = sorted_coefficients.reshape(grid_shape)

            if nbands == 1:
                # this can cause a bug in RegularGridInterpolator. Have to fake
                # having at least two bands
                nbands = 2
                grid_coefficients = np.tile(grid_coefficients, (2, 1, 1, 1, 1))

            interp_range = (np.arange(nbands), x, y, z)

            self.interpolators[spin] = RegularGridInterpolator(
                interp_range,
                grid_coefficients,
                bounds_error=False,
                fill_value=None)

        self.rotation_masks = get_rotation_masks(projections)
        self.band_centers = band_centers