def test_get_thm_relative_grid_addresses_nacl(nacl_lattice):
    """Spot-check the NaCl tetrahedron relative grid addresses.

    The inverse of ``nacl_lattice`` is passed to
    ``get_thm_relative_grid_addresses``. Entries 0 and 23 of the result are
    then compared, in that order, against hand-written vertex offsets.
    """
    reciprocal_basis = np.linalg.inv(nacl_lattice)
    addresses = get_thm_relative_grid_addresses(reciprocal_basis)
    # index into the returned addresses -> expected vertex offsets
    expected_vertices = {
        0: [[0, 0, 0], [1, 0, 0], [1, 1, 0], [1, 1, 1]],
        23: [[0, 0, 0], [-1, -1, -1], [-1, -1, 0], [-1, 0, 0]],
    }
    for index, vertices in expected_vertices.items():
        np.testing.assert_array_equal(addresses[index], vertices)


def test_get_thm_integration_weight_tio2(tio2_lattice,
                                         tio2_phonon_frequences_grg):
    """Reproduce a reference TiO2 DOS on a generalized regular grid.

    For every address of the grid described by ``D_diag``, the tetrahedron
    relative addresses (computed from the microzone basis and mapped through
    ``P``) are shifted onto that address and converted to grid-point indices.
    The band frequencies at those vertices are passed to
    ``get_thm_integration_weight``.

    Summing the default-function weights gives the second reference column.
    Summing the ``function='J'`` weights gives the third column, which grows
    monotonically in the reference data, i.e. an accumulated DOS. Both sums
    are divided by the number of grid points and compared with ``dos_ref``.
    """

    # Columns: frequency sampling point, DOS, accumulated DOS.
    dos_str_df_1 = """0.0 0.00000000 0.00000000
2.5 0.33432705 0.22888315
5.0 2.51042492 2.94515722
7.5 0.85214354 5.32809578
10.0 1.22209820 8.51697799
12.5 1.36331347 10.72763736
15.0 1.84840716 13.72854085
17.5 0.25849856 14.98248161
20.0 0.05704757 16.04840272
22.5 0.40515558 16.44481785
25.0 0.00814283 17.99947146"""

    dos_ref = np.fromstring(dos_str_df_1, sep=' ').reshape(-1, 3)
    freqs = tio2_phonon_frequences_grg
    grid_matrix = [[0, 11, 11], [11, 0, 11], [4, 4, 0]]
    # NOTE(review): D_diag/P look like the Smith normal form of grid_matrix
    # (prod(D_diag) == det(grid_matrix) == 968); confirm against the SNF
    # routine that produced them.
    D_diag = [1, 11, 88]
    P = [[0, -1, 3], [1, 0, 0], [-4, 4, -11]]
    num_gps = np.prod(D_diag)
    num_bands = freqs.shape[1]
    shift = [0, 0, 0]
    rec_lat = np.linalg.inv(tio2_lattice)
    # Basis of a single grid cell: rec_lat times inv(grid_matrix).
    microzone = np.dot(rec_lat, np.linalg.inv(grid_matrix))
    grid_addresses = get_all_grgrid_addresses(D_diag)
    relative_addresses = get_thm_relative_grid_addresses(microzone)
    # Map the relative addresses into the D_diag grid coordinates via P.
    gr_relative_addresses = np.dot(relative_addresses, np.transpose(P))
    fpoints = dos_ref[:, 0]
    dos = np.zeros_like(fpoints)
    acc = np.zeros_like(fpoints)
    for ga in grid_addresses:
        tetrahedra_gps = _get_tetrahedra_grgrid_indices(
            ga + gr_relative_addresses, D_diag, shift)
        tetrahedra_freqs = freqs[tetrahedra_gps]
        # Bands are the outer loop so each band's vertex frequencies are
        # sliced once per grid point. For every dos[i]/acc[i], the order of
        # the additions (grid point, then band) is unchanged.
        for j in range(num_bands):
            band_freqs = tetrahedra_freqs[:, :, j]
            for i, fpt in enumerate(fpoints):
                dos[i] += get_thm_integration_weight(fpt, band_freqs)
                acc[i] += get_thm_integration_weight(fpt, band_freqs,
                                                     function='J')

    dos_dat = np.array([fpoints, dos / num_gps, acc / num_gps]).T
    np.testing.assert_allclose(dos_dat, dos_ref, atol=1e-5)


def test_get_thm_integration_weight(nacl_lattice,
                                    nacl_phonon_frequences_101010):
    """Reproduce a reference NaCl DOS on a 10x10x10 mesh.

    For every address of ``mesh``, the tetrahedron relative addresses are
    shifted onto that address and converted to grid-point indices. The band
    frequencies at those vertices are passed to
    ``get_thm_integration_weight``.

    Summing the default-function weights gives the second reference column.
    Summing the ``function='J'`` weights gives the third column, which grows
    monotonically in the reference data, i.e. an accumulated DOS. Both sums
    are divided by the number of grid points and compared with ``dos_ref``.
    """
    # Columns: frequency sampling point, DOS, accumulated DOS.
    dos_str_df_1 = """0 6.695122056070627165e-05 8.259314625201316929e-07
1 8.911857347254222017e-02 2.553382929530613465e-02
2 4.685384246034071665e-01 2.649562989742248464e-01
3 1.697958549809626128e+00 1.362664683321239689e+00
4 2.388574749472669012e+00 2.596688544653257491e+00
5 3.892363708541792366e+00 4.946293094871299090e+00
6 5.985385823393145621e-01 5.703047775018656118e+00
7 6.634899679165588704e-02 5.991502313827999693e+00"""

    dos_ref = np.fromstring(dos_str_df_1, sep=' ').reshape(-1, 3)
    freqs = nacl_phonon_frequences_101010
    mesh = [10, 10, 10]
    num_gps = np.prod(mesh)
    # Take the band count from the data (as the TiO2 test does) instead of
    # hard-coding 6, so that no band is silently left out of the sums. The
    # reference accumulated DOS approaches 6, consistent with 6 bands here.
    num_bands = freqs.shape[1]
    shift = [0, 0, 0]
    rec_lat = np.linalg.inv(nacl_lattice)
    grid_addresses = get_all_grid_addresses(mesh)
    relative_addresses = get_thm_relative_grid_addresses(rec_lat)
    fpoints = dos_ref[:, 0]
    dos = np.zeros_like(fpoints)
    acc = np.zeros_like(fpoints)
    for ga in grid_addresses:
        tetrahedra_gps = _get_tetrahedra_grid_indices(ga + relative_addresses,
                                                      mesh, shift)
        tetrahedra_freqs = freqs[tetrahedra_gps]
        # Bands are the outer loop so each band's vertex frequencies are
        # sliced once per grid point. For every dos[i]/acc[i], the order of
        # the additions (grid point, then band) is unchanged.
        for j in range(num_bands):
            band_freqs = tetrahedra_freqs[:, :, j]
            for i, fpt in enumerate(fpoints):
                dos[i] += get_thm_integration_weight(fpt, band_freqs)
                acc[i] += get_thm_integration_weight(fpt, band_freqs,
                                                     function='J')

    dos_dat = np.array([fpoints, dos / num_gps, acc / num_gps]).T
    np.testing.assert_allclose(dos_dat, dos_ref, atol=1e-5)