def get_electron_number_density(sw_data: SWIFTDataset) -> cosmo_array:
    """
    Compute the electron number density of the gas particles.

    The gas densities are converted to g/cm^3, multiplied by the ratio
    Xe / (Xe + Xi) and divided by (mu * mp), where Xe, Xi and mu come from
    ``get_molecular_weights``. The result is expressed in cm^-3 and wrapped
    in a new ``cosmo_array`` carrying the cosmo factor of the gas densities.

    Parameters
    ----------
    sw_data: SWIFTDataset
        dataset providing ``gas.densities``; it is also handed to
        ``get_molecular_weights``

    Returns
    -------
    cosmo_array
        electron number density in units of cm**-3
    """
    Xe, Xi, mu = get_molecular_weights(sw_data)
    gas_densities = sw_data.gas.densities

    # mp is a module-level constant (presumably the proton mass -- confirm)
    weight_ratio = Xe / (Xe + Xi)
    number_density = weight_ratio * gas_densities.to('g*cm**-3') / (mu * mp)
    number_density.convert_to_units('cm**-3')

    # Re-wrap so that the cosmological factor of the densities is retained
    return cosmo_array(
        number_density.value,
        units=number_density.units,
        cosmo_factor=gas_densities.cosmo_factor)
Ejemplo n.º 2
0
    def get_radial_distance(
            coords: swiftsimio.cosmo_array,
            centre: unyt_array
    ) -> swiftsimio.cosmo_array:
        """
        Compute the Euclidean distance of every coordinate from a centre.

        Parameters
        ----------
        coords: swiftsimio.cosmo_array
            array of 3D positions, one row per point
        centre: unyt_array
            the 3 components of the reference point; it is converted to the
            units of ``coords`` before the distances are evaluated

        Returns
        -------
        swiftsimio.cosmo_array
            1D array with ``len(coords)`` distances, in the units of
            ``coords`` and carrying the cosmo factor of ``coords``
        """
        # cdist only sees bare numbers, so both inputs must share one unit
        # system; the result is then labelled with that same unit instead of
        # a hard-coded 'Mpc'.
        centre_values = centre.to(coords.units).value.reshape(1, 3)
        separations = distance.cdist(
            coords.value,
            centre_values,
            metric='euclidean'
        )

        return swiftsimio.cosmo_array(
            separations.reshape(len(coords), ),
            units=coords.units,
            cosmo_factor=coords.cosmo_factor
        )
Ejemplo n.º 3
0
    def wrap_coordinates(
            coords: swiftsimio.cosmo_array,
            centre: unyt_array,
            boxsize: unyt_array
    ) -> swiftsimio.cosmo_array:
        """
        Periodically wrap coordinates into a box centred on ``centre``.

        Only the bare numerical values of the three inputs are combined, so
        no unit conversion happens between them.

        Parameters
        ----------
        coords: swiftsimio.cosmo_array
            positions to wrap
        centre: unyt_array
            point on which the wrapped box is centred
        boxsize: unyt_array
            side length(s) of the periodic box

        Returns
        -------
        swiftsimio.cosmo_array
            wrapped positions with the units and cosmo factor of ``coords``
        """
        box_values = boxsize.value
        centre_values = centre.value

        # Shift so the centre sits at the middle of [0, box), wrap, shift back
        shifted = coords.value - centre_values + 0.5 * box_values
        wrapped = np.mod(shifted, box_values) + centre_values - 0.5 * box_values

        return swiftsimio.cosmo_array(
            wrapped,
            units=coords.units,
            cosmo_factor=coords.cosmo_factor
        )
Ejemplo n.º 4
0
def create_single_particle_dataset(filename: str, output_name: str):
    """
    Create an hdf5 snapshot with two particles at an identical location

    The metadata is copied from ``filename``. Both PartType0 particles are
    placed at (1, 1, 1) with unit mass and with the mean smoothing length of
    the input file. The PartType1 cell entries are removed and the header
    particle counts are set to two particles of type 0.

    Parameters
    ----------
    filename: str
        name of file from which to copy metadata
    output_name: str
        name of single particle snapshot
    """
    # Create a dummy mask in order to write metadata
    data_mask = mask(filename)
    boxsize = data_mask.metadata.boxsize
    region = [[0, b] for b in boxsize]
    data_mask.constrain_spatial(region)

    # Context managers guarantee both files are closed even if a step fails
    with h5py.File(filename, "r") as infile, \
            h5py.File(output_name, "w") as outfile:
        # Write the metadata
        list_of_links, _ = find_links(infile)
        write_metadata(infile, outfile, list_of_links, data_mask)

        # Particle data, in the order the datasets are written
        length_units = data_mask.metadata.units.length
        particle_coords = cosmo_array([[1, 1, 1], [1, 1, 1]], length_units)
        particle_masses = cosmo_array([1, 1], data_mask.metadata.units.mass)
        mean_h = mean(infile["/PartType0/SmoothingLengths"])
        particle_h = cosmo_array([mean_h, mean_h], length_units)
        particle_ids = [1, 2]

        particle_fields = (
            ("/PartType0/Coordinates", particle_coords),
            ("/PartType0/Masses", particle_masses),
            ("/PartType0/SmoothingLengths", particle_h),
            ("/PartType0/ParticleIDs", particle_ids),
        )
        for path, values in particle_fields:
            dataset = outfile.create_dataset(path, data=values)
            # Carry over the attributes of the corresponding input dataset
            for name, value in infile[path].attrs.items():
                dataset.attrs.create(name, value)

        # Get rid of all traces of DM
        del outfile["/Cells/Counts/PartType1"]
        del outfile["/Cells/Offsets/PartType1"]
        nparts_total = [2, 0, 0, 0, 0, 0]
        nparts_this_file = [2, 0, 0, 0, 0, 0]
        outfile["/Header"].attrs["NumPart_Total"] = nparts_total
        outfile["/Header"].attrs["NumPart_ThisFile"] = nparts_this_file
Ejemplo n.º 5
0
def feedback_stats_dT(path_to_snap: str, path_to_catalogue: str) -> dict:
    """
    Follow the central black hole of a halo through snapshots and snipshots.

    The target BH is picked in the snapshot at ``path_to_snap``: among the
    BHs of the loaded region whose subgrid mass lies above the 95th
    percentile, the one closest to the minimum-potential centre read from
    ``path_to_catalogue``. Its properties are then collected from every
    snapshot/catalogue pair returned by ``get_allpaths_from_last`` (clipped
    at redshift 5) and, if the module-level flag ``INCLUDE_SNIPS`` is set,
    from the snipshots returned by ``get_snip_handles``.

    The behaviour also depends on the module-level names ``radius_bounds``,
    ``BH_LOCK_ID`` and ``SILENT_PROGRESSBAR``.

    Parameters
    ----------
    path_to_snap: str
        path of the snapshot in which the central BH is identified
    path_to_catalogue: str
        path of the halo catalogue matching ``path_to_snap``

    Returns
    -------
    dict
        ``defaultdict`` with keys 'x', 'y', 'z', 'dx', 'dy', 'dz', 'dr',
        'mass', 'm500c', 'id', 'redshift' and 'time', each holding a
        flattened ``sw.cosmo_array``. Snipshots only contribute to 'x', 'y',
        'z', 'mass', 'id', 'redshift' and 'time', so the catalogue-dependent
        keys can hold fewer entries than the others.

    Raises
    ------
    AssertionError
        if the numbers of snapshots and catalogues found differ
    ValueError
        if a snipshot is processed while ``BH_LOCK_ID`` is not set
    """
    # Read in halo properties
    with h5py.File(f'{path_to_catalogue}', 'r') as h5file:
        XPotMin = unyt.unyt_quantity(h5file['/Xcminpot'][0], unyt.Mpc)
        YPotMin = unyt.unyt_quantity(h5file['/Ycminpot'][0], unyt.Mpc)
        ZPotMin = unyt.unyt_quantity(h5file['/Zcminpot'][0], unyt.Mpc)
        R500c = unyt.unyt_quantity(h5file['/SO_R_500_rhocrit'][0], unyt.Mpc)

    # Read in particles inside a cube of half-width radius_bounds[1] * R500c
    mask = sw.mask(f'{path_to_snap}', spatial_only=True)
    region = [[
        centre - radius_bounds[1] * R500c, centre + radius_bounds[1] * R500c
    ] for centre in (XPotMin, YPotMin, ZPotMin)]
    mask.constrain_spatial(region)
    data = sw.load(f'{path_to_snap}', mask=mask)

    # Get positions for all BHs in the bounding region
    bh_positions = data.black_holes.coordinates
    bh_coordX = bh_positions[:, 0] - XPotMin
    bh_coordY = bh_positions[:, 1] - YPotMin
    bh_coordZ = bh_positions[:, 2] - ZPotMin
    bh_radial_distance = np.sqrt(bh_coordX**2 + bh_coordY**2 + bh_coordZ**2)

    # The central SMBH will probably be massive.
    # Narrow down the search to the BH with top 5% in mass
    bh_masses = data.black_holes.subgrid_masses.to_physical()
    bh_top_massive_index = np.where(
        bh_masses > np.percentile(bh_masses.value, 95))[0]

    # Get the central BH closest to centre of halo at z=0
    central_bh_index = np.argmin(bh_radial_distance[bh_top_massive_index])
    central_bh_id_target = data.black_holes.particle_ids[bh_top_massive_index][
        central_bh_index]

    # Initialise the output dictionary; every key exists even if nothing is
    # appended to it later on.
    central_bh = defaultdict(list)
    for key in ('x', 'y', 'z', 'dx', 'dy', 'dz', 'dr', 'mass', 'm500c', 'id',
                'redshift', 'time'):
        central_bh[key] = []

    # Retrieve BH data from other snaps
    # Clip redshift data (get snaps below that redshifts)
    z_clip = 5.
    all_snaps = get_allpaths_from_last(path_to_snap, z_max=z_clip)
    all_catalogues = get_allpaths_from_last(path_to_catalogue, z_max=z_clip)
    assert len(all_snaps) == len(all_catalogues), (
        f"Detected different number of high-z snaps and high-z catalogues. "
        f"Number of snaps: {len(all_snaps)}. Number of catalogues: {len(all_catalogues)}."
    )

    for i, (highz_snap,
            highz_catalogue) in enumerate(zip(all_snaps, all_catalogues)):

        if not SILENT_PROGRESSBAR:
            print((f"Analysing snap ({i + 1}/{len(all_snaps)}):\n"
                   f"\t{os.path.basename(highz_snap)}\n"
                   f"\t{os.path.basename(highz_catalogue)}"))

        # Load the snapshot BEFORE reading the catalogue, so that the centre
        # is rescaled with the scale factor of this snapshot and not with the
        # one left over from the previous iteration.
        data = sw.load(highz_snap)
        scale_factor = data.metadata.a

        with h5py.File(f'{highz_catalogue}', 'r') as h5file:
            XPotMin = unyt.unyt_quantity(h5file['/Xcminpot'][0],
                                         unyt.Mpc) / scale_factor
            YPotMin = unyt.unyt_quantity(h5file['/Ycminpot'][0],
                                         unyt.Mpc) / scale_factor
            ZPotMin = unyt.unyt_quantity(h5file['/Zcminpot'][0],
                                         unyt.Mpc) / scale_factor
            M500c = unyt.unyt_quantity(
                h5file['/SO_Mass_500_rhocrit'][0] * 1.e10, unyt.Solar_Mass)

        bh_positions = data.black_holes.coordinates.to_physical()
        bh_coordX = bh_positions[:, 0] - XPotMin
        bh_coordY = bh_positions[:, 1] - YPotMin
        bh_coordZ = bh_positions[:, 2] - ZPotMin
        bh_radial_distance = np.sqrt(bh_coordX**2 + bh_coordY**2 +
                                     bh_coordZ**2)
        bh_masses = data.black_holes.subgrid_masses.to_physical()

        if BH_LOCK_ID:
            # NOTE(review): this index array is empty when the target ID is
            # absent from the snapshot, while 'm500c', 'redshift' and 'time'
            # still grow -- the per-key arrays may then be misaligned.
            # TODO confirm the intended handling.
            central_bh_index = np.where(
                data.black_holes.particle_ids.v == central_bh_id_target.v)[0]
        else:
            central_bh_index = np.argmin(bh_radial_distance)

        central_bh['x'].append(bh_positions[central_bh_index, 0])
        central_bh['y'].append(bh_positions[central_bh_index, 1])
        central_bh['z'].append(bh_positions[central_bh_index, 2])
        central_bh['dx'].append(bh_coordX[central_bh_index])
        central_bh['dy'].append(bh_coordY[central_bh_index])
        central_bh['dz'].append(bh_coordZ[central_bh_index])
        central_bh['dr'].append(bh_radial_distance[central_bh_index])
        central_bh['mass'].append(bh_masses[central_bh_index])
        central_bh['m500c'].append(M500c)
        central_bh['id'].append(
            data.black_holes.particle_ids[central_bh_index])
        central_bh['redshift'].append(data.metadata.redshift)
        central_bh['time'].append(data.metadata.time)

    if INCLUDE_SNIPS:
        # Unit system taken from the most recently loaded snapshot
        unitLength = data.metadata.units.length
        unitMass = data.metadata.units.mass
        unitTime = data.metadata.units.time

        snip_handles = get_snip_handles(path_to_snap, z_max=z_clip)
        for snip_handle in tqdm(snip_handles,
                                desc="Analysing snipshots",
                                disable=SILENT_PROGRESSBAR):

            # Open the snipshot file from the I/O stream.
            # Cannot use Swiftsimio, since it automatically looks for PartType0,
            # which is not included in snipshot outputs.
            with h5py.File(snip_handle, 'r') as f:
                bh_positions = f['/PartType5/Coordinates'][...]
                bh_masses = f['/PartType5/SubgridMasses'][...]
                bh_ids = f['/PartType5/ParticleIDs'][...]
                redshift = f['Header'].attrs['Redshift'][0]
                time = f['Header'].attrs['Time'][0]
                a = f['Header'].attrs['Scale-factor'][0]

            if BH_LOCK_ID:
                central_bh_index = np.where(
                    bh_ids == central_bh_id_target.v)[0]
            else:
                raise ValueError((
                    "Trying to lock the central BH to the halo centre of potential "
                    "in snipshots, which do not have corresponding halo catalogues. "
                    "Please, lock the central BH to the particle ID found at z = 0. "
                    "While this set-up is not yet implemented, it would be possible to "
                    "interpolate the position of the CoP between snapshots w.r.t. cosmic time."
                ))

            # This time we need to manually convert to physical coordinates and assign units.
            # Catalogue-dependent quantities (dx, dy, dz, dr, m500c) are not appended.
            central_bh['x'].append(
                unyt.unyt_quantity(bh_positions[central_bh_index, 0] / a,
                                   unitLength))
            central_bh['y'].append(
                unyt.unyt_quantity(bh_positions[central_bh_index, 1] / a,
                                   unitLength))
            central_bh['z'].append(
                unyt.unyt_quantity(bh_positions[central_bh_index, 2] / a,
                                   unitLength))
            central_bh['mass'].append(
                unyt.unyt_quantity(bh_masses[central_bh_index], unitMass))
            central_bh['id'].append(
                unyt.unyt_quantity(bh_ids[central_bh_index],
                                   unyt.dimensionless))
            central_bh['redshift'].append(
                unyt.unyt_quantity(redshift, unyt.dimensionless))
            central_bh['time'].append(unyt.unyt_quantity(time, unitTime))

    # Convert lists to Swiftsimio cosmo arrays
    for key in central_bh:
        central_bh[key] = sw.cosmo_array(central_bh[key]).flatten()
        if not SILENT_PROGRESSBAR:
            print(
                f"Central BH memory [{key}]: {central_bh[key].nbytes / 1024:.3f} kB"
            )

    return central_bh
Ejemplo n.º 6
0
def image_snap(isnap):
    """
    Main function to image one specified snapshot.

    Loads the particles of type ``args.ptype`` around the camera position,
    builds a smoothed map with ``ir.make_sph_image_new_3d``, optionally
    writes the maps (plus BH data) to an HDF5 file and, unless
    ``args.noplot`` is set, plots the image and saves it as a PNG.

    All configuration is read from the module-level ``args`` namespace and
    from other module-level settings (``vr_halo``, ``kernel_gamma``,
    ``fixedSmoothingLength``, ``desNGB``, ``camDir``, ``projectionPlane``,
    ``rho``, ``edge_on``, ``save_maps``, ``show_bhs``). Several ``args``
    attributes (``cambh``, ``hsml``, ``imsize``, ``zsize``) are overwritten
    by this function.

    Parameters
    ----------
    isnap : int
        Index of the snapshot; it is formatted as a zero-padded four-digit
        number in the input and output file names.

    Returns
    -------
    None
        The function returns early if the image already exists (and
        ``args.replot_existing`` is not set) or if the requested BH ID is
        not present in the snapshot.
    """

    print(f"Beginning imaging snapshot {isnap}...")
    stime = time.time()

    plotloc = (args.rootdir +
               f'{args.outdir}/image_pt{args.ptype}_{args.imtype}_'
               f'{args.coda}_')
    if args.cambhbid is not None:
        plotloc = plotloc + f'BH-{args.cambhbid}_'
    if not os.path.isdir(os.path.dirname(plotloc)):
        os.makedirs(os.path.dirname(plotloc))
    if not args.replot_existing and os.path.isfile(
            f'{plotloc}{isnap:04d}.png'):
        print(f"Image {plotloc}{isnap:04d}.png already exists, skipping.")
        return

    snapdir = args.rootdir + f'{args.snap_name}_{isnap:04d}.hdf5'

    mask = sw.mask(snapdir)

    # Read metadata
    print("Read metadata...")
    boxsize = max(mask.metadata.boxsize.value)

    ut = hd.read_attribute(snapdir, 'Units', 'Unit time in cgs (U_t)')[0]
    um = hd.read_attribute(snapdir, 'Units', 'Unit mass in cgs (U_M)')[0]
    time_int = hd.read_attribute(snapdir, 'Header', 'Time')[0]
    aexp_factor = hd.read_attribute(snapdir, 'Header', 'Scale-factor')[0]
    zred = hd.read_attribute(snapdir, 'Header', 'Redshift')[0]
    num_part = hd.read_attribute(snapdir, 'Header', 'NumPart_Total')

    # Internal time -> Gyr, and internal mass/time -> M_sun/yr
    time_gyr = time_int * ut / (3600 * 24 * 365.24 * 1e9)
    mdot_factor = (um / 1.989e33) / (ut / (3600 * 24 * 365.24))

    # -----------------------
    # Snapshot-specific setup
    # -----------------------

    # Camera position
    camPos = None
    if vr_halo >= 0:
        print("Reading camera position from VR catalogue...")
        vr_file = args.rootdir + f'vr_{isnap:04d}.hdf5'
        camPos = hd.read_data(vr_file, 'MinimumPotential/Coordinates')

    elif args.varpos is not None:
        print("Find camera position...")
        if len(args.varpos) != 6:
            print("Need 6 arguments for moving box")
            set_trace()
        # Linear motion: start position plus velocity times time
        camPos = np.array([
            args.varpos[0] + args.varpos[3] * time_gyr,
            args.varpos[1] + args.varpos[4] * time_gyr,
            args.varpos[2] + args.varpos[5] * time_gyr
        ])
        print(camPos)
        camPos *= aexp_factor

    elif args.campos is not None:
        camPos = np.array(args.campos) * aexp_factor
    elif args.campos_phys is not None:
        # args.campos is necessarily None in this branch, so the camera must
        # be taken from args.campos_phys; it is not rescaled by the
        # expansion factor.
        camPos = np.array(args.campos_phys)

    elif args.cambhid is not None:
        all_bh_ids = hd.read_data(snapdir, 'PartType5/ParticleIDs')
        args.cambh = np.nonzero(all_bh_ids == args.cambhid)[0]
        if len(args.cambh) == 0:
            print(f"BH ID {args.cambhid} does not exist, skipping.")
            return

        if len(args.cambh) != 1:
            print(f"Could not unambiguously find BH ID '{args.cambhid}'!")
            set_trace()
        args.cambh = args.cambh[0]

    if args.cambh is not None and camPos is None:
        camPos = hd.read_data(snapdir,
                              'PartType5/Coordinates',
                              read_index=args.cambh) * aexp_factor
        args.hsml = hd.read_data(
            snapdir, 'PartType5/SmoothingLengths',
            read_index=args.cambh) * aexp_factor * kernel_gamma

    elif camPos is None:
        print("Setting camera position to box centre...")
        camPos = np.array([0.5, 0.5, 0.5]) * boxsize * aexp_factor

    # Image size conversion, if necessary
    if not args.propersize:
        args.imsize = args.realimsize * aexp_factor
        args.zsize = args.realzsize * aexp_factor
    else:
        args.imsize = args.realimsize
        args.zsize = args.realzsize

    max_sel = 1.2 * np.sqrt(3) * max(args.imsize, args.zsize)
    extent = np.array([-1, 1, -1, 1]) * args.imsize

    # Set up loading region
    if max_sel < boxsize * aexp_factor / 2:

        load_region = np.array(
            [[camPos[0] - args.imsize * 1.2, camPos[0] + args.imsize * 1.2],
             [camPos[1] - args.imsize * 1.2, camPos[1] + args.imsize * 1.2],
             [camPos[2] - args.zsize * 1.2, camPos[2] + args.zsize * 1.2]])
        load_region = sw.cosmo_array(load_region / aexp_factor, "Mpc")
        mask.constrain_spatial(load_region)
        data = sw.load(snapdir, mask=mask)
    else:
        data = sw.load(snapdir)

    pt_names = ['gas', 'dark_matter', None, None, 'stars', 'black_holes']
    datapt = getattr(data, pt_names[args.ptype])

    pos = datapt.coordinates.value * aexp_factor

    # Next bit does periodic wrapping
    def flip_dim(idim):
        """Shift particles across the box edge if the camera is close to it."""
        full_box_phys = boxsize * aexp_factor
        half_box_phys = boxsize * aexp_factor / 2
        if camPos[idim] < min(max_sel, half_box_phys):
            ind_high = np.nonzero(pos[:, idim] > half_box_phys)[0]
            pos[ind_high, idim] -= full_box_phys
        elif camPos[idim] > max(full_box_phys - max_sel, half_box_phys):
            ind_low = np.nonzero(pos[:, idim] < half_box_phys)[0]
            pos[ind_low, idim] += full_box_phys

    for idim in range(3):
        print(f"Periodic wrapping in dimension {idim}...")
        flip_dim(idim)

    # Keep only particles within max_sel of the camera
    rad = np.linalg.norm(pos - camPos[None, :], axis=1)
    ind_sel = np.nonzero(rad < max_sel)[0]
    pos = pos[ind_sel, :]

    # Read BH properties, if they exist
    if num_part[5] > 0 and not args.nobh:
        bh_hsml = (hd.read_data(snapdir, 'PartType5/SmoothingLengths') *
                   aexp_factor)
        bh_pos = hd.read_data(snapdir, 'PartType5/Coordinates') * aexp_factor
        bh_mass = hd.read_data(snapdir, 'PartType5/SubgridMasses') * 1e10
        bh_maccr = (hd.read_data(snapdir, 'PartType5/AccretionRates') *
                    mdot_factor)
        bh_id = hd.read_data(snapdir, 'PartType5/ParticleIDs')
        bh_nseed = hd.read_data(snapdir, 'PartType5/CumulativeNumberOfSeeds')
        bh_ft = hd.read_data(snapdir, 'PartType5/FormationScaleFactors')
        print(f"Max BH mass: {np.log10(np.max(bh_mass))}")

    else:
        bh_mass = None  # Dummy value

    # Read the appropriate 'mass' quantity
    if args.ptype == 0 and args.imtype == 'sfr':
        mass = datapt.star_formation_rates[ind_sel]
        mass.convert_to_units(unyt.Msun / unyt.yr)
        mass = np.clip(mass.value, 0, None)  # Don't care about last SFR aExp
    else:
        mass = datapt.masses[ind_sel]
        mass.convert_to_units(unyt.Msun)
        mass = mass.value

    if args.ptype == 0:
        hsml = (datapt.smoothing_lengths.value[ind_sel] * aexp_factor *
                kernel_gamma)
    elif fixedSmoothingLength > 0:
        hsml = np.zeros(mass.shape[0], dtype=np.float32) + fixedSmoothingLength
    else:
        hsml = None

    if args.imtype == 'temp':
        quant = datapt.temperatures.value[ind_sel]
    elif args.imtype == 'diffusion_parameters':
        quant = datapt.diffusion_parameters.value[ind_sel]
    else:
        quant = mass

    # Read quantities for gri computation if necessary
    if args.ptype == 4 and args.imtype == 'gri':
        m_init = datapt.initial_masses.value[ind_sel] * 1e10  # in M_sun
        z_star = datapt.metal_mass_fractions.value[ind_sel]
        sft = datapt.birth_scale_factors.value[ind_sel]

        age_star = (time_gyr - hy.aexp_to_time(sft, time_type='age')) * 1e9
        age_star = np.clip(age_star, 0, None)  # Avoid rounding issues

        lum_g = et.imaging.stellar_luminosity(m_init, z_star, age_star, 'g')
        lum_r = et.imaging.stellar_luminosity(m_init, z_star, age_star, 'r')
        lum_i = et.imaging.stellar_luminosity(m_init, z_star, age_star, 'i')

    # ---------------------
    # Generate actual image
    # ---------------------

    xBase = np.zeros(3, dtype=np.float32)
    yBase = np.copy(xBase)
    zBase = np.copy(xBase)

    def render(weight, quantity, smoothing, **extra):
        """Call the SPH imaging routine with the settings shared by all maps."""
        return ir.make_sph_image_new_3d(
            pos,
            weight,
            quantity,
            smoothing,
            DesNgb=desNGB,
            imsize=args.numpix,
            zpix=1,
            boxsize=args.imsize,
            CamPos=camPos,
            CamDir=camDir,
            ProjectionPlane=projectionPlane,
            verbose=True,
            CamAngle=[0, 0, rho],
            rollMode=0,
            edge_on=edge_on,
            treeAllocFac=10,
            xBase=xBase,
            yBase=yBase,
            zBase=zBase,
            **extra)

    def to_maas(image_weight):
        """Convert a luminosity map to the magnitude scale used for plotting."""
        # NOTE(review): 'maas' presumably means mag/arcsec^2 -- confirm
        return -5 / 2 * np.log10(image_weight[:, :, 1] +
                                 1e-15) + 5 * np.log10(
                                     180 * 3600 / np.pi) + 25

    if args.imtype == 'gri':
        # The first call also returns the smoothing lengths, which are then
        # re-used for the other two bands.
        image_weight_all_g, image_quant, hsml = render(
            lum_g, lum_g, hsml, return_hsml=True)
        image_weight_all_r, image_quant = render(
            lum_r, lum_r, hsml, return_hsml=False)
        image_weight_all_i, image_quant = render(
            lum_i, lum_i, hsml, return_hsml=False)

        map_maas_g = to_maas(image_weight_all_g)
        map_maas_r = to_maas(image_weight_all_r)
        map_maas_i = to_maas(image_weight_all_i)

    else:
        image_weight_all, image_quant = render(
            mass, quant, hsml, zrange=[-args.zsize, args.zsize])

        # Extract surface density in M_sun [/yr] / kpc^2
        sigma = np.log10(image_weight_all[:, :, 1] + 1e-15) - 6
        if args.ptype == 0 and args.imtype in ['temp']:
            tmap = np.log10(image_quant[:, :, 1])
        elif args.ptype == 0 and args.imtype in ['diffusion_parameters']:
            tmap = image_quant[:, :, 1]

    # -----------------
    # Save image data
    # -----------------

    if save_maps:
        maploc = plotloc + f'{isnap:04d}.hdf5'

        if args.imtype == 'gri' and args.ptype == 4:
            hd.write_data(maploc, 'g_maas', map_maas_g, new=True)
            hd.write_data(maploc, 'r_maas', map_maas_r)
            hd.write_data(maploc, 'i_maas', map_maas_i)
        else:
            hd.write_data(maploc, 'Sigma', sigma, new=True)
            if args.ptype == 0 and args.imtype == 'temp':
                hd.write_data(maploc, 'Temperature', tmap)
            elif args.ptype == 0 and args.imtype == 'diffusion_parameters':
                hd.write_data(maploc, 'DiffusionParameters', tmap)

        hd.write_data(maploc, 'Extent', extent)
        hd.write_attribute(maploc, 'Header', 'CamPos', camPos)
        hd.write_attribute(maploc, 'Header', 'ImSize', args.imsize)
        hd.write_attribute(maploc, 'Header', 'NumPix', args.numpix)
        hd.write_attribute(maploc, 'Header', 'Redshift', 1 / aexp_factor - 1)
        hd.write_attribute(maploc, 'Header', 'AExp', aexp_factor)
        hd.write_attribute(maploc, 'Header', 'Time', time_gyr)

        if bh_mass is not None:
            hd.write_data(maploc,
                          'BH_pos',
                          bh_pos - camPos[None, :],
                          comment='Relative position of BHs')
            hd.write_data(maploc,
                          'BH_mass',
                          bh_mass,
                          comment='Subgrid mass of BHs')
            hd.write_data(
                maploc,
                'BH_maccr',
                bh_maccr,
                comment='Instantaneous BH accretion rate in M_sun/yr')
            hd.write_data(maploc,
                          'BH_id',
                          bh_id,
                          comment='Particle IDs of BHs')
            hd.write_data(maploc,
                          'BH_nseed',
                          bh_nseed,
                          comment='Number of seeds in each BH')
            hd.write_data(maploc,
                          'BH_aexp',
                          bh_ft,
                          comment='Formation scale factor of each BH')

    # -------------
    # Plot image...
    # -------------

    if not args.noplot:

        print("Obtained image, plotting...")
        fig = plt.figure(figsize=(args.inch, args.inch))
        ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
        plt.sca(ax)

        # Option I: we have really few particles. Plot them individually:
        if pos.shape[0] < 32:
            plt.scatter(pos[:, 0] - camPos[0],
                        pos[:, 1] - camPos[1],
                        color='white')

        else:
            # Main plotting regime

            # Case A: gri image -- very different from rest
            if args.ptype == 4 and args.imtype == 'gri':

                vmin = -args.scale[0] + np.array([-0.5, -0.25, 0.0])
                vmax = -args.scale[1] + np.array([-0.5, -0.25, 0.0])

                # g, r and i maps fill the blue, green and red channels
                clmap_rgb = np.zeros((args.numpix, args.numpix, 3))
                clmap_rgb[:, :, 2] = np.clip(
                    ((-map_maas_g) - vmin[0]) / ((vmax[0] - vmin[0])), 0, 1)
                clmap_rgb[:, :, 1] = np.clip(
                    ((-map_maas_r) - vmin[1]) / ((vmax[1] - vmin[1])), 0, 1)
                clmap_rgb[:, :, 0] = np.clip(
                    ((-map_maas_i) - vmin[2]) / ((vmax[2] - vmin[2])), 0, 1)

                im = plt.imshow(clmap_rgb,
                                extent=extent,
                                aspect='equal',
                                interpolation='nearest',
                                origin='lower',
                                alpha=1.0)

            else:

                # Establish image scaling
                if not args.absscale:
                    ind_use = np.nonzero(sigma > 1e-15)
                    vrange = np.percentile(sigma[ind_use], args.scale)
                else:
                    vrange = args.scale
                print(f'Sigma range: {vrange[0]:.4f} -- {vrange[1]:.4f}')

                # Case B: temperature/diffusion parameter image
                if (args.ptype == 0
                        and args.imtype in ['temp', 'diffusion_parameters']
                        and not args.no_double_image):
                    if args.imtype == 'temp':
                        cmap = None
                    elif args.imtype == 'diffusion_parameters':
                        cmap = cmocean.cm.haline
                    clmap_rgb = ir.make_double_image(
                        sigma,
                        tmap,
                        percSigma=vrange,
                        absSigma=True,
                        rangeQuant=args.quantrange,
                        cmap=cmap)

                    im = plt.imshow(clmap_rgb,
                                    extent=extent,
                                    aspect='equal',
                                    interpolation='nearest',
                                    origin='lower',
                                    alpha=1.0)

                else:
                    # Standard sigma images
                    if args.ptype == 0:
                        if args.imtype == 'hi':
                            cmap = plt.cm.bone
                        elif args.imtype == 'sfr':
                            cmap = plt.cm.magma
                        elif args.imtype == 'diffusion_parameters':
                            cmap = cmocean.cm.haline
                        else:
                            cmap = plt.cm.inferno

                    elif args.ptype == 1:
                        cmap = plt.cm.Greys_r
                    elif args.ptype == 4:
                        cmap = plt.cm.bone

                    if args.no_double_image:
                        # NOTE(review): tmap only exists for gas 'temp' and
                        # 'diffusion_parameters' images.
                        plotquant = tmap
                        vmin, vmax = args.quantrange[0], args.quantrange[1]
                    else:
                        plotquant = sigma
                        vmin, vmax = vrange[0], vrange[1]

                    im = plt.imshow(plotquant,
                                    cmap=cmap,
                                    extent=extent,
                                    vmin=vmin,
                                    vmax=vmax,
                                    origin='lower',
                                    interpolation='nearest',
                                    aspect='equal')

        # Plot BHs if desired:
        if show_bhs and bh_mass is not None:

            if args.bh_file is not None:
                bh_inds = np.loadtxt(args.bh_file, dtype=int)
            else:
                bh_inds = np.arange(bh_pos.shape[0])

            # Restrict to BHs inside the imaged volume and within the
            # requested formation-time and mass ranges
            ind_show = np.nonzero(
                (np.abs(bh_pos[bh_inds, 0] - camPos[0]) < args.imsize)
                & (np.abs(bh_pos[bh_inds, 1] - camPos[1]) < args.imsize)
                & (np.abs(bh_pos[bh_inds, 2] - camPos[2]) < args.zsize)
                & (bh_ft[bh_inds] >= args.bh_ftrange[0])
                & (bh_ft[bh_inds] <= args.bh_ftrange[1])
                & (bh_mass[bh_inds] >= 10.0**args.bh_mrange[0])
                & (bh_mass[bh_inds] <= 10.0**args.bh_mrange[1]))[0]
            ind_show = bh_inds[ind_show]

            if args.bh_quant == 'mass':
                sorter = np.argsort(bh_mass[ind_show])
                sc = plt.scatter(bh_pos[ind_show[sorter], 0] - camPos[0],
                                 bh_pos[ind_show[sorter], 1] - camPos[1],
                                 marker='o',
                                 c=np.log10(bh_mass[ind_show[sorter]]),
                                 edgecolor='grey',
                                 vmin=5.0,
                                 vmax=args.bh_mmax,
                                 s=5.0,
                                 linewidth=0.2)
                bticks = np.linspace(5.0, args.bh_mmax, num=6, endpoint=True)
                blabel = r'log$_{10}$ ($m_\mathrm{BH}$ [M$_\odot$])'

            elif args.bh_quant == 'formation':
                sorter = np.argsort(bh_ft[ind_show])
                sc = plt.scatter(bh_pos[ind_show[sorter], 0] - camPos[0],
                                 bh_pos[ind_show[sorter], 1] - camPos[1],
                                 marker='o',
                                 c=bh_ft[ind_show[sorter]],
                                 edgecolor='grey',
                                 vmin=0,
                                 vmax=1.0,
                                 s=5.0,
                                 linewidth=0.2)
                bticks = np.linspace(0.0, 1.0, num=6, endpoint=True)
                blabel = 'Formation scale factor'

            # Optionally label each BH with its index
            if args.bhind:
                for ibh in ind_show[sorter]:
                    c = plt.cm.viridis(
                        (np.log10(bh_mass[ibh]) - 5.0) / (args.bh_mmax - 5.0))
                    plt.text(bh_pos[ibh, 0] - camPos[0] + args.imsize / 200,
                             bh_pos[ibh, 1] - camPos[1] + args.imsize / 200,
                             f'{ibh}',
                             color=c,
                             fontsize=4,
                             va='bottom',
                             ha='left')

            # Optionally draw a circle of radius args.hsml around the origin
            if args.draw_hsml:
                phi = np.arange(0, 2.01 * np.pi, 0.01)
                plt.plot(args.hsml * np.cos(phi),
                         args.hsml * np.sin(phi),
                         color='white',
                         linestyle=':',
                         linewidth=0.5)

            # Add colour bar for BH masses
            if args.imtype != 'sfr':
                ax2 = fig.add_axes([0.6, 0.07, 0.35, 0.02])
                ax2.set_xticks([])
                ax2.set_yticks([])
                cbar = plt.colorbar(sc,
                                    cax=ax2,
                                    orientation='horizontal',
                                    ticks=bticks)
                cbar.ax.tick_params(labelsize=8)
                fig.text(0.775,
                         0.1,
                         blabel,
                         rotation=0.0,
                         va='bottom',
                         ha='center',
                         color='white',
                         fontsize=8)

        # Done with main image, some embellishments...
        plt.sca(ax)
        plt.text(-0.045 / 0.05 * args.imsize,
                 0.045 / 0.05 * args.imsize,
                 'z = {:.3f}'.format(1 / aexp_factor - 1),
                 va='center',
                 ha='left',
                 color='white')
        plt.text(-0.045 / 0.05 * args.imsize,
                 0.041 / 0.05 * args.imsize,
                 't = {:.3f} Gyr'.format(time_gyr),
                 va='center',
                 ha='left',
                 color='white',
                 fontsize=8)

        plot_bar()

        # Plot colorbar for SFR if appropriate
        # NOTE(review): vrange is only set in the main plotting regime above
        # (at least 32 particles, non-gri image).
        if args.ptype == 0 and args.imtype == 'sfr':
            ax2 = fig.add_axes([0.6, 0.07, 0.35, 0.02])
            ax2.set_xticks([])
            ax2.set_yticks([])

            # Off-screen dummy point that only serves as colour bar mappable
            scc = plt.scatter([-1e10], [-1e10],
                              c=[0],
                              cmap=plt.cm.magma,
                              vmin=vrange[0],
                              vmax=vrange[1])
            cbar = plt.colorbar(scc,
                                cax=ax2,
                                orientation='horizontal',
                                ticks=np.linspace(np.floor(vrange[0]),
                                                  np.ceil(vrange[1]),
                                                  5,
                                                  endpoint=True))
            cbar.ax.tick_params(labelsize=8)
            fig.text(
                0.775,
                0.1,
                r'log$_{10}$ ($\Sigma_\mathrm{SFR}$ [M$_\odot$ yr$^{-1}$ kpc$^{-2}$])',
                rotation=0.0,
                va='bottom',
                ha='center',
                color='white',
                fontsize=8)

        ax.set_xlabel(r'$\Delta x$ [pMpc]')
        ax.set_ylabel(r'$\Delta y$ [pMpc]')

        ax.set_xlim((-args.imsize, args.imsize))
        ax.set_ylim((-args.imsize, args.imsize))

        plt.savefig(plotloc + str(isnap).zfill(4) + '.png',
                    dpi=args.numpix / args.inch)
        plt.close()

    print(f"Finished snapshot {isnap} in {(time.time() - stime):.2f} sec.")
    print(f"Image saved in {plotloc}{isnap:04d}.png")