def from_band_structure(
    cls,
    band_structure: BandStructure,
    energy_cutoff=defaults["energy_cutoff"],
    symprec=defaults["symprec"],
):
    """Construct the calculator from a band structure object.

    The fractional coordinates of ``band_structure.kpoints`` are expanded to
    the full k-point mesh with ``expand_kpoints``. Only the bands chosen by
    ``get_ibands(energy_cutoff, band_structure)`` are kept. Band centres are
    computed from the selected energies mapped onto the full mesh.

    Args:
        band_structure: Source of k-points, bands, projections, Fermi level
            and crystal structure.
        energy_cutoff: Passed to ``get_ibands`` and ``get_vb_idx`` to select
            bands.
        symprec: Symmetry precision passed to ``expand_kpoints``.

    Returns:
        A new instance of ``cls``. The k-point symmetry mapping computed here
        is forwarded so the constructor does not recompute it.
    """
    kpt_coords = np.array([kpt.frac_coords for kpt in band_structure.kpoints])
    fermi_level = band_structure.efermi
    crystal = band_structure.structure

    mesh_kpoints, ir_to_mesh, rotations = expand_kpoints(
        crystal, kpt_coords, symprec=symprec
    )
    band_selection = get_ibands(energy_cutoff, band_structure)
    vb_idx = get_vb_idx(energy_cutoff, band_structure)

    # Index bands and k-points in two separate steps. If band_selection[spin]
    # is an index array, e[bands, ir_to_mesh] would pair the two index arrays
    # element-wise (NumPy advanced indexing) instead of taking the
    # bands x k-points sub-array.
    selected_energies = {
        spin: spin_bands[band_selection[spin]]
        for spin, spin_bands in band_structure.bands.items()
    }
    mesh_energies = {}
    for spin, spin_bands in selected_energies.items():
        mesh_energies[spin] = spin_bands[:, ir_to_mesh]

    # Projections stay on the original k-points; the constructor expands them.
    selected_projections = {
        spin: spin_proj[band_selection[spin]]
        for spin, spin_proj in band_structure.projections.items()
    }

    centers = get_band_centers(mesh_kpoints, mesh_energies, vb_idx, fermi_level)

    symmetry_mapping = (mesh_kpoints, ir_to_mesh, rotations)
    return cls(
        crystal,
        kpt_coords,
        selected_projections,
        centers,
        kpoint_symmetry_mapping=symmetry_mapping,
    )
def __init__(
    self,
    structure,
    kpoints,
    projections,
    band_centers,
    symprec=defaults["symprec"],
    kpoint_symmetry_mapping=None,
):
    """Build per-spin interpolators of normalised projection coefficients.

    Args:
        structure: Crystal structure. It is passed to ``expand_kpoints`` when
            no ``kpoint_symmetry_mapping`` is supplied.
        kpoints: Fractional k-point coordinates. They may cover only the
            irreducible part of a regular grid.
        projections: Dict mapping spin to a projections array. Axis 0 is the
            band index and axis 1 the k-point index, which is expanded via
            ``ir_to_full_idx``. Any remaining axes are flattened into a single
            projection axis.
        band_centers: Stored unchanged as ``self.band_centers``.
        symprec: Symmetry precision for ``expand_kpoints``. Only used when
            ``kpoint_symmetry_mapping`` is None.
        kpoint_symmetry_mapping: Optional precomputed
            ``(full_kpoints, ir_to_full_idx, rot_mapping)`` tuple, as returned
            by ``expand_kpoints``.

    Sets ``self.nbands`` (bands per spin), ``self.interpolators`` (spin ->
    ``RegularGridInterpolator`` over (band index, kx, ky, kz)),
    ``self.rotation_masks`` and ``self.band_centers``.
    """
    logger.info("Initializing orbital overlap calculator")

    # k-points have to be on a regular grid, even if only the irreducible part
    # of the grid is given. Expand the k-points (and below, the projections)
    # to the full BZ using the symmetry mapping.
    if kpoint_symmetry_mapping is not None:
        full_kpoints, ir_to_full_idx, _ = kpoint_symmetry_mapping
    else:
        full_kpoints, ir_to_full_idx, _ = expand_kpoints(
            structure, kpoints, symprec=symprec
        )

    # tuple() ensures "mesh_dim + (3,)" concatenates shapes. If the helper
    # returned an array, "+" would broadcast-add instead.
    mesh_dim = tuple(get_mesh_dim_from_kpoints(full_kpoints))

    # Round to 6 decimal places (a 1e-6 tolerance) so that numerically equal
    # coordinates compare equal in the lexicographic sort below.
    round_dp = 6
    full_kpoints = np.round(full_kpoints, round_dp)

    # np.lexsort uses the LAST key as the primary one. This therefore sorts by
    # kx, then ky, then kz, with kz varying fastest, which matches the
    # [ikx][iky][ikz] grid layout built next.
    sort_idx = np.lexsort(
        (full_kpoints[:, 2], full_kpoints[:, 1], full_kpoints[:, 0])
    )

    # Put the k-points into a 3D grid so they can be indexed as
    # kpoints[ikx][iky][ikz] = [kx, ky, kz].
    grid_kpoints = full_kpoints[sort_idx].reshape(mesh_dim + (3,))
    x = grid_kpoints[:, 0, 0, 0]
    y = grid_kpoints[0, :, 0, 1]
    z = grid_kpoints[0, 0, :, 2]

    self.nbands = {s: p.shape[0] for s, p in projections.items()}

    # TODO: Expand the k-point mesh to account for periodic boundary conditions
    nkpoints = len(ir_to_full_idx)
    self.interpolators = {}
    for spin, spin_projections in projections.items():
        nbands = spin_projections.shape[0]
        # np.product was removed in NumPy 2.0. int() also covers the case
        # with no trailing axes, where np.prod(()) returns the float 1.0.
        nprojections = int(np.prod(spin_projections.shape[2:]))

        expand_projections = spin_projections[:, ir_to_full_idx]
        flat_projections = expand_projections.reshape(
            (nbands, nkpoints, -1), order="F"
        )

        # Aim is to get the wavefunction coefficients: normalise each
        # (band, k-point) projection vector. All-zero vectors give 0/0 = NaN,
        # which is deliberately replaced by 0. Silence the expected warnings.
        with np.errstate(divide="ignore", invalid="ignore"):
            norms = np.sqrt((flat_projections ** 2).sum(axis=2))
            coefficients = flat_projections / norms[..., None]
        coefficients[np.isnan(coefficients)] = 0

        # Sort the coefficients, then reshape them into the grid. They can
        # then be indexed as coefficients[iband][ikx][iky][ikz].
        sorted_coefficients = coefficients[:, sort_idx]
        grid_shape = (nbands,) + mesh_dim + (nprojections,)
        grid_coefficients = sorted_coefficients.reshape(grid_shape)

        if nbands == 1:
            # A single band can cause a bug in RegularGridInterpolator, so
            # fake having two (identical) bands.
            nbands = 2
            grid_coefficients = np.tile(grid_coefficients, (2, 1, 1, 1, 1))

        interp_range = (np.arange(nbands), x, y, z)
        # bounds_error=False with fill_value=None makes the interpolator
        # extrapolate outside the grid instead of raising.
        self.interpolators[spin] = RegularGridInterpolator(
            interp_range, grid_coefficients, bounds_error=False, fill_value=None
        )

    self.rotation_masks = get_rotation_masks(projections)
    self.band_centers = band_centers