from transfer.frontends.simbaahf import SIMBASnapshotData
from scipy.spatial import cKDTree
import numpy

# For every dark-matter particle in SIMBA snapshot 151, find the distance to
# the nearest AHF halo centre in a periodic box and write the distances to a
# text file (one value per line, in Mpc).

# Snapshot 000 is opened only to read its largest gas particle ID. That value
# (+1) is passed as the truncation threshold for particle-type keys 0 and 4
# when the final snapshot is loaded.
# NOTE(review): keys presumably follow the Gadget/SIMBA convention
# (0 = gas, 1 = dark matter, 4 = stars); confirm the exact truncation
# semantics against SIMBASnapshotData.
initial_snap = SIMBASnapshotData(
    '/cosma6/data/dp004/dc-borr1/simba-test-data/snap_m50n512_000.hdf5',
    halo_filename=None)

# Reduce over the (large) gas ID array once and reuse the result.
max_initial_gas_id = initial_snap.gas.particle_ids.max()

truncate_ids = {
    0: max_initial_gas_id + 1,
    1: None,
    4: max_initial_gas_id + 1,
}

final = SIMBASnapshotData(
    '/cosma6/data/dp004/dc-borr1/simba-test-data/snap_m50n512_151.hdf5',
    '/cosma6/data/dp004/dc-borr1/simba-test-data/snap151Rpep..z0.000.AHF_halos',
    truncate_ids=truncate_ids)

# Bring the box size and both coordinate sets into Mpc. convert_to_units works
# in place, so the attributes on `final` are converted too. The distances
# returned by the KD-tree are therefore in Mpc.
boxsize = final.boxsize
boxsize.convert_to_units('Mpc')
dm_coords = final.dark_matter.coordinates
halo_coords = final.halo_coordinates
dm_coords.convert_to_units('Mpc')
halo_coords.convert_to_units('Mpc')

# Periodic KD-tree over halo centres. cKDTree requires the tree data to lie
# within [0, boxsize) along each axis when boxsize is given.
tree = cKDTree(halo_coords, boxsize=boxsize.value)

# Query with all available cores. Newer SciPy spells the parallelism argument
# ``workers`` (``n_jobs`` was deprecated and later removed). Fall back to
# ``n_jobs`` only on old SciPy versions that do not accept ``workers``.
try:
    d, i = tree.query(dm_coords, k=1, workers=-1)
except TypeError:
    d, i = tree.query(dm_coords, k=1, n_jobs=-1)

numpy.savetxt(
    '/cosma5/data/durham/dc-murr1/simba_dm_nearest_halo_distance.txt', d)
from transfer.spreadmetric import SpreadMetricCalculator
from transfer.frontends.simbaahf import SIMBASnapshotData
from transfer.holder import SimulationData
import pickle

# Load the initial and final SIMBA snapshots (the final one with its AHF halo
# catalogue) and pair them in a SimulationData object for the per-particle-type
# transfer analysis below.
initial_filename = "/cosma6/data/dp004/dc-borr1/simba-test-data/snap_m50n512_000.hdf5"
final_filename = "/cosma6/data/dp004/dc-borr1/simba-test-data/snap_m50n512_151.hdf5"

# No halo catalogue is loaded for the initial snapshot.
initial_halo_filename = None
final_halo_filename = (
    "/cosma6/data/dp004/dc-borr1/simba-test-data/snap151Rpep..z0.000.AHF_halos"
)

initial = SIMBASnapshotData(initial_filename, initial_halo_filename)

# Largest gas particle ID in the initial snapshot. It is computed once here
# rather than reducing over the (large) ID array for each dictionary entry.
max_initial_gas_id = initial.gas.particle_ids.max()

# Truncation thresholds passed to the final snapshot for particle-type keys
# 0 and 4; key 1 gets None.
# NOTE(review): keys presumably follow the Gadget/SIMBA convention
# (0 = gas, 1 = dark matter, 4 = stars); confirm against SIMBASnapshotData.
truncate_ids = {
    0: max_initial_gas_id + 1,
    1: None,
    4: max_initial_gas_id + 1,
}

final = SIMBASnapshotData(final_filename,
                          final_halo_filename,
                          truncate_ids=truncate_ids)
cross = SimulationData(initial_snapshot=initial, final_snapshot=final)

for particle_type in ["dark_matter", "stars", "gas"]:
    x = {
        n: getattr(getattr(cross, f"{particle_type}_transfer"), n)
        for n in [
            "in_halo",