def from_band_structure(
    cls,
    band_structure: BandStructure,
    energy_cutoff=defaults["energy_cutoff"],
    symprec=defaults["symprec"],
):
    """Build an instance from the orbital projections of a band structure.

    The band-structure k-points are expanded with ``expand_kpoints`` (time
    reversal enabled). Only the bands selected by ``get_ibands`` for
    ``energy_cutoff`` are kept. The energies and projections of those bands
    are unfolded onto the expanded k-points. Projections are normalized per
    band and k-point before everything is handed to ``cls.from_data``.

    Args:
        band_structure: Band structure supplying k-points, band energies,
            Fermi level, structure and projections.
        energy_cutoff: Cutoff forwarded to ``get_ibands`` and ``get_vb_idx``
            to choose which bands are included.
        symprec: Symmetry precision forwarded to ``expand_kpoints``.

    Returns:
        The object produced by ``cls.from_data``.
    """
    kpoints = np.array([k.frac_coords for k in band_structure.kpoints])
    efermi = band_structure.efermi
    structure = band_structure.structure

    full_kpoints, _, _, _, _, ir_to_full_idx = expand_kpoints(
        structure, kpoints, symprec=symprec, return_mapping=True, time_reversal=True
    )
    nkpoints = len(full_kpoints)

    ibands = get_ibands(energy_cutoff, band_structure)
    vb_idx = get_vb_idx(energy_cutoff, band_structure)

    # keep only the selected bands, then unfold them onto the full k-point set
    energies = {
        s: e[ibands[s]][:, ir_to_full_idx] for s, e in band_structure.bands.items()
    }
    projections = {s: p[ibands[s]] for s, p in band_structure.projections.items()}

    band_centers = get_band_centers(full_kpoints, energies, vb_idx, efermi)
    rotation_mask = get_rotation_mask(projections)

    full_projections = {}
    for spin, spin_projections in projections.items():
        nbands = spin_projections.shape[0]
        # unfold onto the full k-points and flatten every trailing axis into a
        # single projection axis (Fortran order); NOTE(review): assumes the
        # array is laid out as (nbands, nkpoints, ...) — confirm upstream
        spin_projections = spin_projections[:, ir_to_full_idx].reshape(
            (nbands, nkpoints, -1), order="F"
        )
        norms = np.linalg.norm(spin_projections, axis=2, keepdims=True)
        # rows with zero norm yield 0/0 = NaN and are reset to zero just below,
        # so the floating-point warnings from that division are expected noise
        with np.errstate(divide="ignore", invalid="ignore"):
            spin_projections /= norms
        spin_projections[np.isnan(spin_projections)] = 0
        full_projections[spin] = spin_projections

    return cls.from_data(
        full_kpoints,
        full_projections,
        rotation_mask=rotation_mask,
        band_centers=band_centers,
    )
def from_coefficients(
    cls, coefficients, gpoints, kpoints, structure, symprec=defaults["symprec"]
):
    """Create a wavefunction overlap calculator from plane-wave coefficients.

    If the number of k-points equals the size of the mesh returned by
    ``get_mesh_from_kpoint_numbers``, the inputs are used as-is. Otherwise
    the k-points are expanded with ``expand_kpoints`` (time reversal
    enabled). The coefficients are then desymmetrized onto the expanded set
    with ``desymmetrize_coefficients``.

    Args:
        coefficients: Wavefunction coefficients for ``kpoints``.
        gpoints: Plane-wave G-points associated with ``coefficients``.
        kpoints: K-points at which ``coefficients`` are given.
        structure: Crystal structure used for the symmetry expansion.
        symprec: Symmetry precision forwarded to ``expand_kpoints``.

    Returns:
        ``cls(kpoints, coefficients, gpoints)`` built from either the original
        or the expanded/desymmetrized data.
    """
    logger.info("Initializing wavefunction overlap calculator")

    mesh_dim = get_mesh_from_kpoint_numbers(kpoints)
    # np.prod instead of np.product: the latter was removed in NumPy 2.0
    if np.prod(mesh_dim) == len(kpoints):
        # k-point count already matches the full mesh; no expansion needed
        return cls(kpoints, coefficients, gpoints)

    full_kpoints, *symmetry_mapping = expand_kpoints(
        structure, kpoints, time_reversal=True, return_mapping=True, symprec=symprec
    )
    coefficients = desymmetrize_coefficients(
        coefficients, gpoints, kpoints, structure, *symmetry_mapping
    )
    return cls(full_kpoints, coefficients, gpoints)
def test_expand_kpoints(symmetry_structure, shift, mesh):
    """Check that expand_kpoints rebuilds the full k-point mesh from IR k-points.

    The reference mesh and its irreducible subset come from spglib's
    ``get_ir_reciprocal_mesh``. The irreducible k-points are expanded with
    ``expand_kpoints`` and compared against the reference mesh in two ways:
    directly, and by re-applying the returned rotation / k-point mapping.
    """
    # k-points are rounded to 8 decimals before sorting; this tolerance is far
    # above that rounding but far below any realistic mesh spacing
    atol = 1e-6
    # accept list/tuple parametrizations as well as arrays
    mesh = np.asarray(mesh)
    shift = np.asarray(shift)

    def _kpoints_to_first_bz(kp):
        """Map k-points into (-0.5, 0.5], folding -0.5 onto +0.5."""
        kp = np.array(kp)
        kp = kp - np.round(kp)
        kp[kp.round(8) == -0.5] = 0.5
        return kp

    def _sort_kpoints(kp):
        """Round k-points and put them in a consistent lexicographic order."""
        kp = kp.round(8)
        sort_idx = np.lexsort((kp[:, 2], kp[:, 1], kp[:, 0]))
        return kp[sort_idx]

    # generate true k-points and IR k-points using spglib
    atoms = AseAtomsAdaptor.get_atoms(symmetry_structure)
    mapping, addresses = get_ir_reciprocal_mesh(mesh, atoms, is_shift=shift)
    true_kpoints = _kpoints_to_first_bz(addresses / mesh + shift / (mesh * 2))
    true_kpoints_sort = _sort_kpoints(true_kpoints)

    # unique entries of spglib's mapping are the grid indices of the IR points
    ir_kpoints = true_kpoints[np.unique(mapping)]

    # try to expand the irreducible k-points back to the full BZ
    full_kpoints, rots, _, _, op_mapping, kp_mapping = expand_kpoints(
        symmetry_structure, ir_kpoints, return_mapping=True
    )
    full_kpoints_sort = _sort_kpoints(_kpoints_to_first_bz(full_kpoints))

    # final k-points must match the expected true k-points
    np.testing.assert_allclose(
        full_kpoints_sort, true_kpoints_sort, rtol=0, atol=atol
    )

    # the rotation mapping must regenerate the same full set of k-points
    rotated_kpoints = [
        np.dot(rots[r], ir_kpoints[k]) for r, k in zip(op_mapping, kp_mapping)
    ]
    rotated_kpoints_sort = _sort_kpoints(_kpoints_to_first_bz(rotated_kpoints))
    np.testing.assert_allclose(
        rotated_kpoints_sort, true_kpoints_sort, rtol=0, atol=atol
    )
def from_deformation_potentials(cls, deformation_potentials, kpoints, structure, symprec=defaults["symprec"]):
    """Create a deformation-potential interpolator.

    If the number of k-points equals the size of the mesh returned by
    ``get_mesh_from_kpoint_numbers``, the data are used as-is. Otherwise the
    k-points are expanded with ``expand_kpoints`` (time reversal enabled).
    The deformation potentials are then desymmetrized onto the expanded set
    with ``desymmetrize_deformation_potentials``.

    Args:
        deformation_potentials: Deformation potentials given at ``kpoints``.
        kpoints: K-points at which ``deformation_potentials`` are given.
        structure: Crystal structure used for the symmetry expansion.
        symprec: Symmetry precision forwarded to ``expand_kpoints``.

    Returns:
        The object produced by ``cls.from_data``.
    """
    logger.info("Initializing deformation potential interpolator")

    mesh_dim = get_mesh_from_kpoint_numbers(kpoints)
    # np.prod instead of np.product: the latter was removed in NumPy 2.0
    if np.prod(mesh_dim) == len(kpoints):
        # k-point count already matches the full mesh; no expansion needed
        return cls.from_data(kpoints, deformation_potentials)

    full_kpoints, rotations, _, _, op_mapping, kp_mapping = expand_kpoints(
        structure, kpoints, time_reversal=True, return_mapping=True, symprec=symprec
    )
    logger.warning("Desymmetrizing deformation potentials, this could go wrong.")
    deformation_potentials = desymmetrize_deformation_potentials(
        deformation_potentials, structure, rotations, op_mapping, kp_mapping
    )
    return cls.from_data(full_kpoints, deformation_potentials)