def extract_bh_data(snapfile, isim, isnap, args):
    """Copy the redshift and selected BH fields of one snapshot to the output.

    Parameters
    ----------
    snapfile : str
        Snapshot file to read the 'PartType5' data sets from.
    isim : int or other
        Simulation label. Integers are written as 'ID<isim>', anything
        else is inserted into the output group name as-is.
    isnap : int
        Snapshot index, used in the output group name.
    args : object
        Must provide `outfile` (the output file) and `bh_fields` (names
        of the 'PartType5' data sets to copy).
    """
    sim_label = f'ID{isim}' if isinstance(isim, int) else f'{isim}'
    pre = f'{sim_label}/S{isnap}/BH/'

    # The snapshot redshift is attached as an attribute to the BH group
    redshift = hd.read_attribute(snapfile, 'Header', 'Redshift')[0]
    hd.write_attribute(args.outfile, pre, 'Redshift', redshift)

    for field in args.bh_fields:
        values = hd.read_data(snapfile, 'PartType5/' + field)
        if values is None:
            # Nothing was read for this field, so there is nothing to copy
            continue
        hd.write_data(args.outfile, f'{pre}/{field}', values)
Exemple #2
0
def add_coda_to_offsets(vr_part_file):
    """Append the total particle count as a final entry to each offset list.

    For each of the groups 'Haloes', 'Unbound', and 'Groups', the data set
    '<group>/Offsets' is read from `vr_part_file`, extended by the matching
    total read from the 'Header' attributes, and written back under the
    same name.
    """
    # Header attribute that holds the total particle count of each group
    total_attr_by_group = (
        ('Haloes', 'NumberOfBoundParticles_Total'),
        ('Unbound', 'NumberOfUnboundParticles_Total'),
        ('Groups', 'NumberOfSOParticles_Total'),
    )

    for group, total_attr in total_attr_by_group:
        dset_name = f'{group}/Offsets'
        current = read_data(vr_part_file, dset_name)
        total = read_attribute(vr_part_file, 'Header', total_attr)
        write_data(vr_part_file, dset_name,
                   np.concatenate((current, [total])))
def extract_vr_data(vrfile, isim, isnap, args):
    """Copy the requested VR catalogue fields of one snapshot to the output.

    Parameters
    ----------
    vrfile : str
        VR catalogue file to read from.
    isim : int or other
        Simulation label. Integers are written as 'ID<isim>', anything
        else is inserted into the output group name as-is.
    isnap : int
        Snapshot index, used in the output group name.
    args : object
        Must provide `outfile` (the output file) and `vr_fields`, a
        sequence whose entries hold the data set name in the VR file
        (index 0) and the name to store it under in the output (index 1).
    """
    sim_label = f'ID{isim}' if isinstance(isim, int) else f'{isim}'
    pre = f'{sim_label}/S{isnap}/'

    for field in args.vr_fields:
        source_name = field[0]
        values = hd.read_data(vrfile, source_name)
        if values is not None:
            hd.write_data(args.outfile, f'{pre}/{field[1]}', values)
            continue

        # Missing fields are reported but do not abort the extraction
        print(f"Could not find field '{source_name}' in VR file '{vrfile}'!")
Exemple #4
0
def process_snapshot(isim, isnap, args):
    """Cross-match the haloes of one snapshot with its DM-only counterpart.

    Haloes are matched in both directions and only bijective matches are
    kept. The results are written to the group 'MatchInDMO' of the VR file
    of this simulation: the index of the matched DMO halo ('Haloes', -1 if
    there is none), both match fractions, and a few DMO halo properties
    aligned to the haloes of this simulation (NaN where unmatched).

    Parameters
    ----------
    isim : object
        Not used by this function.
    isnap : int
        Index of the snapshot to process.
    args : object
        Must provide `wdir`, `dmo_dir`, and `vr_name`, from which the VR
        catalogue locations are built.
    """
    print("")
    print(f"Processing snapshot {isnap}...")

    base_this = f'{args.wdir}{args.vr_name}_{isnap:04d}'
    base_dmo = f'{args.dmo_dir}{args.vr_name}_{isnap:04d}'
    vr_file_this = f'{base_this}.hdf5'
    vr_file_dmo = f'{base_dmo}.hdf5'

    snap_this = Snapshot(base_this)
    snap_dmo = Snapshot(base_dmo, sim_type='dm-only')

    # Forward direction: best DMO match for each halo in this simulation
    print("\nMatching from this to DMO...")
    forward = match_haloes(snap_this, snap_dmo)
    match_in_dmo, _, match_frac_in_dmo = forward

    # Reverse direction: best match in this simulation for each DMO halo
    print("\nMatching from DMO to this...")
    reverse = match_haloes(snap_dmo, snap_this)
    match_in_this, _, match_frac_in_this = reverse

    # A match only counts if the reverse match leads back to the same halo
    own_index = np.arange(snap_this.n_haloes)
    leads_back_to = match_in_this[match_in_dmo]
    rejected = (match_in_dmo < 0) | (leads_back_to != own_index)
    match_in_dmo[rejected] = -1
    unmatched = match_in_dmo < 0

    # Write out results
    hd.write_data(vr_file_this, 'MatchInDMO/Haloes', match_in_dmo)
    hd.write_data(vr_file_this, 'MatchInDMO/MatchFractionInDMO',
                  match_frac_in_dmo)

    frac_in_this = match_frac_in_this[match_in_dmo]
    frac_in_this[unmatched] = np.nan
    hd.write_data(vr_file_this, 'MatchInDMO/MatchFractionInThis',
                  frac_in_this)

    for iset in ('M200crit', 'Masses', 'MaximumCircularVelocities'):
        aligned = hd.read_data(vr_file_dmo, iset)[match_in_dmo]
        aligned[unmatched] = np.nan
        hd.write_data(vr_file_this, f'MatchInDMO/{iset}', aligned)
Exemple #5
0
def convert_sfr(ascii_file, hdf5_file, sim_type='swift', unit_sfr=1.022690e-2):
    """Convert an ASCII star formation rate log to an HDF5 file.

    The expansion factor, redshift, and star formation rate are written to
    `hdf5_file` as the data sets 'aExp', 'Redshift', and 'SFR'.

    Parameters
    ----------
    ascii_file : str
        The ASCII log file to read.
    hdf5_file : str
        The HDF5 file to write the converted data to.
    sim_type : str, optional
        Format of the log (case-insensitive). For 'swift' (default), the
        columns 3, 4, and 8 are used; any other value is treated as a
        Gadget-style log with the expansion factor in column 1 and the
        SFR in column 3.
    unit_sfr : float, optional
        Factor that the SFR column is multiplied by. Only applied for
        `sim_type` 'swift'.
    """
    t_start = time.time()
    print(f"Reading file '{ascii_file}'...", end='', flush=True)
    table = ascii.read(ascii_file)
    print(f"done (took {(time.time()-t_start):.3f} sec.)")

    def column(name):
        """Return one column of the table as a plain array."""
        return np.array(table[name])

    if sim_type.lower() != 'swift':
        # Gadget-style SFR log: redshift must be derived from aExp
        aexp = column('col1')
        zred = 1 / aexp - 1
        sfr = column('col3')
    else:
        aexp = column('col3')
        zred = column('col4')
        sfr = column('col8') * unit_sfr

    outputs = (('aExp', aexp), ('Redshift', zred), ('SFR', sfr))
    for dset_name, values in outputs:
        hd.write_data(hdf5_file, dset_name, values)
Exemple #6
0
def make_plot(args, vr_data, iplot, isnap, iibh=None, ibh=None):
    """Make one specific plot.

    Parameters
    ----------
    args : object
        Configuration parameters read from the arg parser
    vr_data : dict
        Relevant VR catalogue data for all BHs to be plotted
    iplot : tuple
        Information about the specific plot to make.
        Content: [x-quant, y-quant, color-quant]
    isnap : int
        Snapshot index of this plot; only relevant for output naming.
    iibh : int, optional
        BH index of target BH in vr_data. If None (default), no BH is
        highlighted especially.
    ibh : int, optional
        Full BH-ID of target BH (only used for output naming). The
        normalised point coordinates are only written to the HDF5 file
        if this is None (default).
    """
    plotloc = (f'{args.wdir}{args.plot_prefix}_{iplot[0]}-{iplot[1]}-'
               f'{iplot[2]}_BH-{ibh}_snap-{isnap}.png')
    fig = plt.figure(figsize=(5.5, 4.5))

    # To enable the HTML link map, we must be able to reconstruct the
    # location of each point on the plot. This is easier with an explicitly
    # constructed axes frame ([x_off, y_off, x_width, y_width])
    ax = fig.add_axes([0.15, 0.15, 0.67, 0.8])

    xr, yr = plot_ranges[iplot[0]], plot_ranges[iplot[1]]
    ax.set_xlim(xr)
    ax.set_ylim(yr)
    ax.set_xlabel(ax_labels[iplot[0]])
    ax.set_ylabel(ax_labels[iplot[1]])
    vmin, vmax = plot_ranges[iplot[2]]

    # Extract relevant quantities
    xquant = get_vrquant(vr_data, iplot[0])
    yquant = get_vrquant(vr_data, iplot[1])
    cquant = get_vrquant(vr_data, iplot[2])

    # Copies that receive the (possibly clipped) plotted coordinates
    xquant_plt = np.copy(xquant)
    yquant_plt = np.copy(yquant)

    if args.show_median:
        plot_distribution(xquant,
                          yquant,
                          xrange=xr,
                          uncertainty=True,
                          scatter=False,
                          plot_at_median=True,
                          color='grey',
                          nbins=5,
                          dashed_below=5)

    # Handle of the most recently drawn point; stays None if all points
    # end up being skipped.
    sc = None

    for iixbh in range(len(xquant)):

        # Special treatment for (halo of) to-be-highlighted BH
        if iixbh == iibh and iibh is not None:
            s, edgecolor = 50.0, 'red'
        else:
            s, edgecolor = 15.0, 'none'

        marker = 'o'

        ixquant, iyquant = xquant[iixbh], yquant[iixbh]

        # Points that are outside the plot range in both x and y are dropped
        if ((ixquant < xr[0] or ixquant > xr[1])
                and (iyquant < yr[0] or iyquant > yr[1])):
            continue

        # Points outside the range in one dimension are moved just inside
        # the respective edge and drawn as (enlarged) arrows.
        if ixquant < xr[0]:
            ixquant = xr[0] + (xr[1] - xr[0]) * 0.02
            marker = r'$\mathbf{\leftarrow}$'
            s *= 3
        elif ixquant > xr[1]:
            ixquant = xr[1] - (xr[1] - xr[0]) * 0.02
            marker = r'$\mathbf{\rightarrow}$'
            s *= 3
        if iyquant < yr[0]:
            iyquant = yr[0] + (yr[1] - yr[0]) * 0.02
            marker = r'$\mathbf{\downarrow}$'
            s *= 3
        elif iyquant > yr[1]:
            iyquant = yr[1] - (yr[1] - yr[0]) * 0.02
            marker = r'$\mathbf{\uparrow}$'
            s *= 3

        icquant = np.clip(cquant[iixbh], vmin, vmax)

        # 'q * 0 != 0' is True exactly for non-finite values (NaN, +-inf)
        if ixquant * 0 != 0: continue
        # NOTE: debugging hook, drops into the debugger for an unexpected
        # non-finite y value.
        if iyquant * 0 != 0 and iplot[1] != 'MBH': set_trace()
        if icquant * 0 != 0:
            icquant = 'grey'
        sc = plt.scatter([ixquant], [iyquant],
                         c=[icquant],
                         cmap=plt.cm.viridis,
                         marker=marker,
                         vmin=vmin,
                         vmax=vmax,
                         s=s,
                         edgecolor=edgecolor,
                         zorder=100)

        # Feed possible clipping changes back to array, for outputting.
        xquant_plt[iixbh], yquant_plt[iixbh] = ixquant, iyquant

    if sc is None:
        # Nothing was drawn, so there is no scatter handle to derive the
        # colour bar from. Use a stand-alone mappable with the same colour
        # map and range instead of failing on an unbound name.
        sc = plt.cm.ScalarMappable(norm=plt.Normalize(vmin=vmin, vmax=vmax),
                                   cmap=plt.cm.viridis)
        sc.set_array([])

    # Colour bar on the right-hand side
    ax2 = fig.add_axes([0.82, 0.15, 0.04, 0.8])
    ax2.set_yticks([])
    ax2.set_xticks([])
    plt.colorbar(sc, cax=ax2, orientation='vertical')
    fig.text(0.99,
             0.5,
             ax_labels[iplot[2]],
             color='black',
             rotation=90.0,
             fontsize=10,
             ha='right',
             va='center')

    plt.savefig(plotloc, dpi=200, transparent=False)
    plt.close('all')

    # Final bit: store normalised data points in HDF5 file
    # (only done for plots that are not specific to one BH)
    if ibh is not None:
        return

    # Convert data coordinates to fractions of the figure, using the
    # offsets and extents of the main axes frame constructed above.
    imx = (xquant_plt - xr[0]) / (xr[1] - xr[0])
    imx = imx * 0.67 + 0.15

    # y is additionally flipped and scaled by the figure aspect ratio
    imy = (yquant_plt - yr[0]) / (yr[1] - yr[0])
    imy = (1 - (imy * 0.8 + 0.15)) * (4.5 / 5.5)

    hd.write_data(args.sim_pointdata_loc,
                  f'S{isnap}/{iplot[0]}-{iplot[1]}/xpt',
                  imx,
                  comment='Normalised x coordinates of points.')
    hd.write_data(args.sim_pointdata_loc,
                  f'S{isnap}/{iplot[0]}-{iplot[1]}/ypt',
                  imy,
                  comment='Normalised y coordinates of points.')
Exemple #7
0
def process_sim(args, isim, have_full_sim_dir):
    """Generate the images for one particular simulation.

    Parameters
    ----------
    args : object
        Configuration parameters. This function also sets `args.wdir`,
        `args.catloc`, and `args.sim_pointdata_loc` on it.
    isim : str or int
        The simulation to process. Used directly as the working directory
        if `have_full_sim_dir` is True, otherwise passed to
        `xl.get_sim_dir` to look the directory up.
    have_full_sim_dir : bool
        Whether `isim` already is the full simulation directory.
    """
    if have_full_sim_dir:
        args.wdir = isim
    else:
        args.wdir = xl.get_sim_dir(args.base_dir, isim)

    # Name of the input BH data catalogue
    args.catloc = f'{args.wdir}{args.bh_file}'

    # Find BHs we are interested in, load data (of internal VR match)
    select_list = [["Halo_MStar", '>=', args.halo_mstar_range[0]],
                   ["Halo_MStar", '<', args.halo_mstar_range[1]],
                   ["DMO_M200c", '>=', args.halo_m200_range[0]],
                   ["DMO_M200c", '<', args.halo_m200_range[1]]]
    if not args.include_subdominant_bhs:
        select_list.append(['Flag_MostMassiveInHalo', '==', 1])
    if not args.include_satellites:
        select_list.append(['HaloTypes', '==', 10])

    if args.bh_mass_range is not None:
        # The mass selection is applied at the catalogue output that is
        # closest in redshift to the requested selection redshift.
        zreds = hd.read_data(args.catloc, 'Redshifts')
        best_index = np.argmin(np.abs(zreds - args.bh_selection_redshift))
        print(f"Best index for redshift {args.bh_selection_redshift} is "
              f"{best_index}.")

        # Subgrid masses are in 10^10 M_sun, so need to adjust selection range
        select_list.append(
            ['SubgridMasses', '>=', args.bh_mass_range[0] / 1e10, best_index])
        select_list.append(
            ['SubgridMasses', '<=', args.bh_mass_range[1] / 1e10, best_index])

    bh_props_list = ['SubgridMasses', 'Redshifts', 'ParticleIDs']
    bh_data, bh_list = xl.lookup_bh_data(args.catloc, bh_props_list,
                                         select_list)

    if len(bh_list) == 0:
        print("No BHs selected, aborting.")
        return

    args.sim_pointdata_loc = args.wdir + args.plot_prefix + '.hdf5'

    # Make sure the output directory exists. An empty directory name means
    # the current directory, which needs no creation (and would make
    # os.makedirs fail).
    pointdata_dir = os.path.dirname(args.sim_pointdata_loc)
    if pointdata_dir:
        os.makedirs(pointdata_dir, exist_ok=True)

    # Keep a previously written file as backup rather than overwriting it
    if os.path.isfile(args.sim_pointdata_loc):
        shutil.move(args.sim_pointdata_loc, args.sim_pointdata_loc + '.old')
    hd.write_data(args.sim_pointdata_loc,
                  'BlackHoleBIDs',
                  bh_list,
                  comment='BIDs of all BHs selected for this simulation.')

    # Go through snapshots
    for isnap in args.snapshots:

        print("")
        print(f"Processing snapshot {isnap}...")

        # Need to explicitly connect BHs to VR catalogue from this snap
        vr_data = xl.connect_to_galaxies(
            bh_data['ParticleIDs'][bh_list],
            f'{args.wdir}{args.vr_file}_{isnap:04d}',
            extra_props=[('ApertureMeasurements/Projection1/30kpc/'
                          'Stars/HalfMassRadii', 'StellarHalfMassRad'),
                         ('MatchInDMO/M200crit', 'DMO_M200c')])

        # Add subgrid mass of BHs themselves, taken from the BH catalogue
        # output closest in redshift to this snapshot
        ind_bh_cat = np.argmin(
            np.abs(bh_data['Redshifts'] - vr_data['Redshift']))
        print(f"Using BH masses from index {ind_bh_cat}")
        vr_data['BH_SubgridMasses'] = (
            bh_data['SubgridMasses'][bh_list, ind_bh_cat] * 1e10)

        # 'q * 0 == 0' is True exactly for finite values
        n_good_m200 = np.count_nonzero(vr_data['DMO_M200c'] * 0 == 0)
        print(f"Have {n_good_m200} BHs with DMO_M200c, out of "
              f"{len(bh_list)}.")

        # Make plots for each BH, and "general" overview plot
        print("Overview plot")
        generate_vr_plots(args, vr_data, isnap)

        if not args.summary_only:
            for iibh, ibh in enumerate(bh_list):
                print(f"Make plot for BH-BID {ibh} ({iibh}/{len(bh_list)})")
                generate_vr_plots(args, vr_data, isnap, iibh, ibh)

        print("Done!")
Exemple #8
0
def transcribe_data(data_list,
                    vrfile_in,
                    outfile,
                    kind='simple',
                    form='scalar',
                    mixed_source=False):
    """Transcribe data sets.

    Besides its parameters, this function relies on several module-level
    names (e.g. `args`, `num_haloes`, `aperture_list`, and the type and
    dimension symbol tables).

    Parameters
    ----------
    data_list : tuple
        A list of the transcription keys to process. Each key is a tuple
        of (VR_name, Out_name, Comment, Conversion_Factor, Type_list).
        For `kind` 'apertures', index 5 is also read: if it is False, only
        the 3D aperture (no projections) is transcribed. For mixed-source
        input, index 6 is appended to `vrfile_in` to form the source file.
    vrfile_in : str
        The VR file to transcribe data from.
    outfile : str
        The output file to store transcribed data in.
    kind : str
        The kind of transcription we are doing. Options are
            - 'profiles' --> transcribe profile data
            - 'apertures' --> transcribe aperture measurements
        Any other value (including the default) leaves the data set names
        without a kind-specific prefix.
    form : str
        Form of data elements. Options are
            - 'scalar' --> simple scalar data quantity
            - '3darray' --> transcribe 3d array quantities
            - '3x3matrix' --> transcribe 3x3 matrix quantities
        Any unrecognised value is treated like 'scalar'.
    mixed_source : bool, optional
        If True, index 6 specifies the source VR file and `vrfile_in` is
        assumed to be the common base instead.

    Raises
    ------
    IndexError
        If `mixed_source` is True but a key has no entry at index 6.
    """
    for ikey in data_list:
        if len(ikey) < 5: set_trace()

        # Deal with possibility of 'None' in Type_list (no type iteration)
        types = [None] if ikey[4] is None else ikey[4]

        # Deal with possibility of mixed-source input
        if mixed_source:
            if len(ikey) < 7:
                # Fail with a clear message instead of an anonymous
                # out-of-range error on the lookup below.
                raise IndexError("Need to specify source file in index 6 for "
                                 "mixed source transcription!")
            vrfile = vrfile_in + ikey[6]
        else:
            vrfile = vrfile_in

        if not os.path.isfile(vrfile):
            print("Could not find input VR file...")
            set_trace()

        # Some quantities use capital X/Y/Z in VR...
        if ikey[0] in [
                'V?c', 'V?cmbp', 'V?cminpot', '?c', '?cmbp', '?cminpot'
        ]:
            dimsyms = dimsymbols_cap
        else:
            dimsyms = dimsymbols

        # Iterate over aperture types (only relevant for apertures)
        if kind == 'apertures':
            n_proj = 4
            ap_list = aperture_list
        else:
            n_proj = 1
            ap_list = [None]

        for iap in ap_list:
            for iproj in range(n_proj):

                # Some special treatment for apertures
                if kind == 'apertures':
                    # Skip all projections if the key does not want them
                    if iproj > 0 and ikey[5] is False:
                        break

                    # Index 0 is the 3D aperture, 1-3 are the projections
                    if iproj == 0:
                        ap_prefix = 'Aperture_'
                        ap_outfix = '/'
                    else:
                        ap_prefix = f'Projected_aperture_{iproj}_'
                        ap_outfix = f'/Projection{iproj}/'
                else:
                    ap_prefix = ''
                    ap_outfix = ''

                # Iterate over required types
                for itype in types:

                    # Construct type specifiers in in- and output, and
                    # fill the type placeholder ('#') in the comment
                    if itype is None:
                        typefix_in = ''
                        typefix_out = ''
                        comment = ikey[2].replace('#', '')
                    else:
                        typefix_in = type_in[itype]
                        typefix_out = type_out[itype]
                        comment = ikey[2].replace('#', type_self[itype])

                    # Construct the full data set names in in- and output
                    vrname = ikey[0] + typefix_in
                    outname = typefix_out + ikey[1]

                    # Deal with names for special cases
                    if kind == 'profiles':
                        outname = f'Profiles/{outname}'
                    elif kind == 'apertures':
                        vrname = f'{ap_prefix}{vrname}_{iap}_kpc'
                        outname = (f'ApertureMeasurements{ap_outfix}{iap}kpc/'
                                   f'{outname}')

                    # Transcribe data
                    if args.verbose:
                        print(f"{vrname} --> {outname}")

                    if form == '3darray':
                        outdata = np.zeros(
                            (num_haloes, 3), dtype=np.float32) - 1

                        # Load individual dimensions' data sets into output;
                        # '?' in the name stands for the dimension symbol
                        for idim in range(3):
                            vrname_dim = vrname.replace('?', dimsyms[idim])
                            outdata[:, idim] = read_data(vrfile,
                                                         vrname_dim,
                                                         require=True)

                    elif form == '3x3matrix':
                        outdata = np.zeros(
                            (num_haloes, 3, 3), dtype=np.float32) - 1

                        # '?' and '*' stand for the first and second
                        # dimension symbol, respectively
                        for idim1 in range(3):
                            for idim2 in range(3):
                                vrname_dim = (vrname.replace(
                                    '?', dimsyms[idim1]).replace(
                                        '*', dimsyms[idim2]))
                                outdata[:, idim1,
                                        idim2] = read_data(vrfile,
                                                           vrname_dim,
                                                           require=True)
                    else:
                        # Standard case (scalar quantity)
                        outdata = read_data(vrfile, vrname, require=True)

                    # Apply the conversion factor, if there is one
                    if ikey[3] is not None:
                        outdata *= ikey[3]

                    write_data(outfile, outname, outdata, comment=comment)
def write_output_file(output_dict, comment_dict, bpart_ids,
                      bpart_first_outputs, gal_props, args):
    """Write the completed arrays to an HDF5 file.

    Parameters
    ----------
    output_dict : dict
        The data sets to write, keyed by their name in the output file.
    comment_dict : dict
        The comment to attach to each data set; must contain every key of
        `output_dict`.
    bpart_ids : array
        Written as data set 'ParticleIDs'.
    bpart_first_outputs : array
        Written as data set 'FirstIndices'.
    gal_props : dict or None
        Properties of the haloes containing the black holes. If None, no
        halo data are written. Otherwise, it must contain the keys
        'Haloes', 'MStar', 'SFR', 'M200c', 'HaloTypes',
        'flag_most_massive_bh', 'DMO_Haloes', and 'DMO_M200crit'.
    args : object
        Must provide `out_dir` and `out_file` (concatenated to form the
        output location), `redshifts`, `times`, and, if `gal_props` is not
        None, `vr_snap`, `vr_zred`, and `vr_aexp`.
    """
    outfile = args.out_dir + args.out_file
    print(f"Writing output file '{outfile}...'")

    hd.write_data(outfile, 'ParticleIDs', bpart_ids)
    hd.write_data(outfile, 'FirstIndices', bpart_first_outputs)
    hd.write_data(outfile, 'Redshifts', args.redshifts)
    hd.write_data(outfile, 'Times', args.times)

    if gal_props is not None:

        hd.write_data(outfile,
                      'Haloes',
                      gal_props['Haloes'],
                      comment='Index of the velociraptor halo containing each '
                      f'black hole at redshift {args.vr_zred:.3f}.')

        # Record which VR catalogue the halo data refer to
        hd.write_attribute(outfile, 'Haloes', 'VR_Snapshot', args.vr_snap)
        hd.write_attribute(outfile, 'Haloes', 'VR_Redshift', args.vr_zred)
        hd.write_attribute(outfile, 'Haloes', 'VR_ScaleFactor', args.vr_aexp)

        hd.write_data(outfile,
                      'Halo_MStar',
                      gal_props['MStar'],
                      comment='Stellar mass (< 30kpc) of the halo containing '
                      f'the black holes at redshift {args.vr_zred:.3f} '
                      '[M_sun].')
        hd.write_data(outfile,
                      'Halo_SFR',
                      gal_props['SFR'],
                      comment='Star formation rates (< 30kpc) of the halo '
                      'containing the black holes at redshift '
                      f'{args.vr_zred:.3f} [M_sun/yr].')
        # The redshift must be interpolated here, so this needs an f-string
        hd.write_data(outfile,
                      'Halo_M200c',
                      gal_props['M200c'],
                      comment='Halo virial masses (M200c) of the halo '
                      'containing the black holes at redshift '
                      f'{args.vr_zred:.3f} [M_sun].')
        hd.write_data(outfile,
                      'HaloTypes',
                      gal_props['HaloTypes'],
                      comment='Types of the haloes containing the black holes '
                      f'at redshift {args.vr_zred:.3f}. Central haloes have '
                      'a value of 10.')
        hd.write_data(outfile,
                      'Flag_MostMassiveInHalo',
                      gal_props['flag_most_massive_bh'],
                      comment='1 if this is the most massive black hole in its '
                      f'halo at redshift {args.vr_zred}, 0 otherwise.')

        hd.write_data(outfile,
                      'DMO_Haloes',
                      gal_props['DMO_Haloes'],
                      comment='Index of the matched velociraptor halo in the '
                      'corresponding DM-only simulation, at redshift '
                      f'{args.vr_zred} (-1 if no bijective match).')

        hd.write_data(outfile,
                      'DMO_M200c',
                      gal_props['DMO_M200crit'],
                      comment='Virial masses (M200c) of the matched halo in '
                      'the corresponding DM-only simulation at '
                      f'redshift {args.vr_zred}.')

    for dset, data in output_dict.items():
        hd.write_data(outfile, dset, data, comment=comment_dict[dset])

    print("...done!")
Exemple #10
0
def image_snap(isnap):
    """Main function to image one specified snapshot."""

    print(f"Beginning imaging snapshot {isnap}...")
    stime = time.time()

    plotloc = (args.rootdir +
               f'{args.outdir}/image_pt{args.ptype}_{args.imtype}_'
               f'{args.coda}_')
    if args.cambhbid is not None:
        plotloc = plotloc + f'BH-{args.cambhbid}_'
    if not os.path.isdir(os.path.dirname(plotloc)):
        os.makedirs(os.path.dirname(plotloc))
    if not args.replot_existing and os.path.isfile(
            f'{plotloc}{isnap:04d}.png'):
        print(f"Image {plotloc}{isnap:04d}.png already exists, skipping.")
        return

    snapdir = args.rootdir + f'{args.snap_name}_{isnap:04d}.hdf5'

    mask = sw.mask(snapdir)

    # Read metadata
    print("Read metadata...")
    boxsize = max(mask.metadata.boxsize.value)

    ut = hd.read_attribute(snapdir, 'Units', 'Unit time in cgs (U_t)')[0]
    um = hd.read_attribute(snapdir, 'Units', 'Unit mass in cgs (U_M)')[0]
    time_int = hd.read_attribute(snapdir, 'Header', 'Time')[0]
    aexp_factor = hd.read_attribute(snapdir, 'Header', 'Scale-factor')[0]
    zred = hd.read_attribute(snapdir, 'Header', 'Redshift')[0]
    num_part = hd.read_attribute(snapdir, 'Header', 'NumPart_Total')

    time_gyr = time_int * ut / (3600 * 24 * 365.24 * 1e9)
    mdot_factor = (um / 1.989e33) / (ut / (3600 * 24 * 365.24))

    # -----------------------
    # Snapshot-specific setup
    # -----------------------

    # Camera position
    camPos = None
    if vr_halo >= 0:
        print("Reading camera position from VR catalogue...")
        vr_file = args.rootdir + f'vr_{isnap:04d}.hdf5'
        camPos = hd.read_data(vr_file, 'MinimumPotential/Coordinates')

    elif args.varpos is not None:
        print("Find camera position...")
        if len(args.varpos) != 6:
            print("Need 6 arguments for moving box")
            set_trace()
        camPos = np.array([
            args.varpos[0] + args.varpos[3] * time_gyr,
            args.varpos[1] + args.varpos[4] * time_gyr,
            args.varpos[2] + args.varpos[5] * time_gyr
        ])
        print(camPos)
        camPos *= aexp_factor

    elif args.campos is not None:
        camPos = np.array(args.campos) * aexp_factor
    elif args.campos_phys is not None:
        camPos = np.array(args.campos)

    elif args.cambhid is not None:
        all_bh_ids = hd.read_data(snapdir, 'PartType5/ParticleIDs')
        args.cambh = np.nonzero(all_bh_ids == args.cambhid)[0]
        if len(args.cambh) == 0:
            print(f"BH ID {args.cambhid} does not exist, skipping.")
            return

        if len(args.cambh) != 1:
            print(f"Could not unambiguously find BH ID '{args.cambhid}'!")
            set_trace()
        args.cambh = args.cambh[0]

    if args.cambh is not None and camPos is None:
        camPos = hd.read_data(snapdir,
                              'PartType5/Coordinates',
                              read_index=args.cambh) * aexp_factor
        args.hsml = hd.read_data(
            snapdir, 'PartType5/SmoothingLengths',
            read_index=args.cambh) * aexp_factor * kernel_gamma

    elif camPos is None:
        print("Setting camera position to box centre...")
        camPos = np.array([0.5, 0.5, 0.5]) * boxsize * aexp_factor

    # Image size conversion, if necessary
    if not args.propersize:
        args.imsize = args.realimsize * aexp_factor
        args.zsize = args.realzsize * aexp_factor
    else:
        args.imsize = args.realimsize
        args.zsize = args.realzsize

    max_sel = 1.2 * np.sqrt(3) * max(args.imsize, args.zsize)
    extent = np.array([-1, 1, -1, 1]) * args.imsize

    # Set up loading region
    if max_sel < boxsize * aexp_factor / 2:

        load_region = np.array(
            [[camPos[0] - args.imsize * 1.2, camPos[0] + args.imsize * 1.2],
             [camPos[1] - args.imsize * 1.2, camPos[1] + args.imsize * 1.2],
             [camPos[2] - args.zsize * 1.2, camPos[2] + args.zsize * 1.2]])
        load_region = sw.cosmo_array(load_region / aexp_factor, "Mpc")
        mask.constrain_spatial(load_region)
        data = sw.load(snapdir, mask=mask)
    else:
        data = sw.load(snapdir)

    pt_names = ['gas', 'dark_matter', None, None, 'stars', 'black_holes']
    datapt = getattr(data, pt_names[args.ptype])

    pos = datapt.coordinates.value * aexp_factor

    # Next bit does periodic wrapping
    def flip_dim(idim):
        full_box_phys = boxsize * aexp_factor
        half_box_phys = boxsize * aexp_factor / 2
        if camPos[idim] < min(max_sel, half_box_phys):
            ind_high = np.nonzero(pos[:, idim] > half_box_phys)[0]
            pos[ind_high, idim] -= full_box_phys
        elif camPos[idim] > max(full_box_phys - max_sel, half_box_phys):
            ind_low = np.nonzero(pos[:, idim] < half_box_phys)[0]
            pos[ind_low, idim] += full_box_phys

    for idim in range(3):
        print(f"Periodic wrapping in dimension {idim}...")
        flip_dim(idim)

    rad = np.linalg.norm(pos - camPos[None, :], axis=1)
    ind_sel = np.nonzero(rad < max_sel)[0]
    pos = pos[ind_sel, :]

    # Read BH properties, if they exist
    if num_part[5] > 0 and not args.nobh:
        bh_hsml = (hd.read_data(snapdir, 'PartType5/SmoothingLengths') *
                   aexp_factor)
        bh_pos = hd.read_data(snapdir, 'PartType5/Coordinates') * aexp_factor
        bh_mass = hd.read_data(snapdir, 'PartType5/SubgridMasses') * 1e10
        bh_maccr = (hd.read_data(snapdir, 'PartType5/AccretionRates') *
                    mdot_factor)
        bh_id = hd.read_data(snapdir, 'PartType5/ParticleIDs')
        bh_nseed = hd.read_data(snapdir, 'PartType5/CumulativeNumberOfSeeds')
        bh_ft = hd.read_data(snapdir, 'PartType5/FormationScaleFactors')
        print(f"Max BH mass: {np.log10(np.max(bh_mass))}")

    else:
        bh_mass = None  # Dummy value

    # Read the appropriate 'mass' quantity
    if args.ptype == 0 and args.imtype == 'sfr':
        mass = datapt.star_formation_rates[ind_sel]
        mass.convert_to_units(unyt.Msun / unyt.yr)
        mass = np.clip(mass.value, 0, None)  # Don't care about last SFR aExp
    else:
        mass = datapt.masses[ind_sel]
        mass.convert_to_units(unyt.Msun)
        mass = mass.value

    if args.ptype == 0:
        hsml = (datapt.smoothing_lengths.value[ind_sel] * aexp_factor *
                kernel_gamma)
    elif fixedSmoothingLength > 0:
        hsml = np.zeros(mass.shape[0], dtype=np.float32) + fixedSmoothingLength
    else:
        hsml = None

    if args.imtype == 'temp':
        quant = datapt.temperatures.value[ind_sel]
    elif args.imtype == 'diffusion_parameters':
        quant = datapt.diffusion_parameters.value[ind_sel]
    else:
        quant = mass

    # Read quantities for gri computation if necessary
    if args.ptype == 4 and args.imtype == 'gri':
        m_init = datapt.initial_masses.value[ind_sel] * 1e10  # in M_sun
        z_star = datapt.metal_mass_fractions.value[ind_sel]
        sft = datapt.birth_scale_factors.value[ind_sel]

        age_star = (time_gyr - hy.aexp_to_time(sft, time_type='age')) * 1e9
        age_star = np.clip(age_star, 0, None)  # Avoid rounding issues

        lum_g = et.imaging.stellar_luminosity(m_init, z_star, age_star, 'g')
        lum_r = et.imaging.stellar_luminosity(m_init, z_star, age_star, 'r')
        lum_i = et.imaging.stellar_luminosity(m_init, z_star, age_star, 'i')

    # ---------------------
    # Generate actual image
    # ---------------------

    xBase = np.zeros(3, dtype=np.float32)
    yBase = np.copy(xBase)
    zBase = np.copy(xBase)

    if args.imtype == 'gri':
        image_weight_all_g, image_quant, hsml = ir.make_sph_image_new_3d(
            pos,
            lum_g,
            lum_g,
            hsml,
            DesNgb=desNGB,
            imsize=args.numpix,
            zpix=1,
            boxsize=args.imsize,
            CamPos=camPos,
            CamDir=camDir,
            ProjectionPlane=projectionPlane,
            verbose=True,
            CamAngle=[0, 0, rho],
            rollMode=0,
            edge_on=edge_on,
            treeAllocFac=10,
            xBase=xBase,
            yBase=yBase,
            zBase=zBase,
            return_hsml=True)
        image_weight_all_r, image_quant = ir.make_sph_image_new_3d(
            pos,
            lum_r,
            lum_r,
            hsml,
            DesNgb=desNGB,
            imsize=args.numpix,
            zpix=1,
            boxsize=args.imsize,
            CamPos=camPos,
            CamDir=camDir,
            ProjectionPlane=projectionPlane,
            verbose=True,
            CamAngle=[0, 0, rho],
            rollMode=0,
            edge_on=edge_on,
            treeAllocFac=10,
            xBase=xBase,
            yBase=yBase,
            zBase=zBase,
            return_hsml=False)
        image_weight_all_i, image_quant = ir.make_sph_image_new_3d(
            pos,
            lum_i,
            lum_i,
            hsml,
            DesNgb=desNGB,
            imsize=args.numpix,
            zpix=1,
            boxsize=args.imsize,
            CamPos=camPos,
            CamDir=camDir,
            ProjectionPlane=projectionPlane,
            verbose=True,
            CamAngle=[0, 0, rho],
            rollMode=0,
            edge_on=edge_on,
            treeAllocFac=10,
            xBase=xBase,
            yBase=yBase,
            zBase=zBase,
            return_hsml=False)

        map_maas_g = -5 / 2 * np.log10(image_weight_all_g[:, :, 1] +
                                       1e-15) + 5 * np.log10(
                                           180 * 3600 / np.pi) + 25
        map_maas_r = -5 / 2 * np.log10(image_weight_all_r[:, :, 1] +
                                       1e-15) + 5 * np.log10(
                                           180 * 3600 / np.pi) + 25
        map_maas_i = -5 / 2 * np.log10(image_weight_all_i[:, :, 1] +
                                       1e-15) + 5 * np.log10(
                                           180 * 3600 / np.pi) + 25

    else:
        image_weight_all, image_quant = ir.make_sph_image_new_3d(
            pos,
            mass,
            quant,
            hsml,
            DesNgb=desNGB,
            imsize=args.numpix,
            zpix=1,
            boxsize=args.imsize,
            CamPos=camPos,
            CamDir=camDir,
            ProjectionPlane=projectionPlane,
            verbose=True,
            CamAngle=[0, 0, rho],
            rollMode=0,
            edge_on=edge_on,
            treeAllocFac=10,
            xBase=xBase,
            yBase=yBase,
            zBase=zBase,
            zrange=[-args.zsize, args.zsize])

        # Extract surface density in M_sun [/yr] / kpc^2
        sigma = np.log10(image_weight_all[:, :, 1] + 1e-15) - 6
        if args.ptype == 0 and args.imtype in ['temp']:
            tmap = np.log10(image_quant[:, :, 1])
        elif args.ptype == 0 and args.imtype in ['diffusion_parameters']:
            tmap = image_quant[:, :, 1]

    # -----------------
    # Save image data
    # -----------------

    if save_maps:
        maploc = plotloc + f'{isnap:04d}.hdf5'

        if args.imtype == 'gri' and args.ptype == 4:
            hd.write_data(maploc, 'g_maas', map_maas_g, new=True)
            hd.write_data(maploc, 'r_maas', map_maas_r)
            hd.write_data(maploc, 'i_maas', map_maas_i)
        else:
            hd.write_data(maploc, 'Sigma', sigma, new=True)
            if args.ptype == 0 and args.imtype == 'temp':
                hd.write_data(maploc, 'Temperature', tmap)
            elif args.ptype == 0 and args.imtype == 'diffusion_parameters':
                hd.write_data(maploc, 'DiffusionParameters', tmap)

        hd.write_data(maploc, 'Extent', extent)
        hd.write_attribute(maploc, 'Header', 'CamPos', camPos)
        hd.write_attribute(maploc, 'Header', 'ImSize', args.imsize)
        hd.write_attribute(maploc, 'Header', 'NumPix', args.numpix)
        hd.write_attribute(maploc, 'Header', 'Redshift', 1 / aexp_factor - 1)
        hd.write_attribute(maploc, 'Header', 'AExp', aexp_factor)
        hd.write_attribute(maploc, 'Header', 'Time', time_gyr)

        if bh_mass is not None:
            hd.write_data(maploc,
                          'BH_pos',
                          bh_pos - camPos[None, :],
                          comment='Relative position of BHs')
            hd.write_data(maploc,
                          'BH_mass',
                          bh_mass,
                          comment='Subgrid mass of BHs')
            hd.write_data(
                maploc,
                'BH_maccr',
                bh_maccr,
                comment='Instantaneous BH accretion rate in M_sun/yr')
            hd.write_data(maploc,
                          'BH_id',
                          bh_id,
                          comment='Particle IDs of BHs')
            hd.write_data(maploc,
                          'BH_nseed',
                          bh_nseed,
                          comment='Number of seeds in each BH')
            hd.write_data(maploc,
                          'BH_aexp',
                          bh_ft,
                          comment='Formation scale factor of each BH')

    # -------------
    # Plot image...
    # -------------

    if not args.noplot:

        print("Obtained image, plotting...")
        fig = plt.figure(figsize=(args.inch, args.inch))
        ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
        plt.sca(ax)

        # Option I: we have really few particles. Plot them individually:
        if pos.shape[0] < 32:
            plt.scatter(pos[:, 0] - camPos[0],
                        pos[:, 1] - camPos[1],
                        color='white')

        else:
            # Main plotting regime

            # Case A: gri image -- very different from rest
            if args.ptype == 4 and args.imtype == 'gri':

                vmin = -args.scale[0] + np.array([-0.5, -0.25, 0.0])
                vmax = -args.scale[1] + np.array([-0.5, -0.25, 0.0])

                clmap_rgb = np.zeros((args.numpix, args.numpix, 3))
                clmap_rgb[:, :, 2] = np.clip(
                    ((-map_maas_g) - vmin[0]) / ((vmax[0] - vmin[0])), 0, 1)
                clmap_rgb[:, :, 1] = np.clip(
                    ((-map_maas_r) - vmin[1]) / ((vmax[1] - vmin[1])), 0, 1)
                clmap_rgb[:, :, 0] = np.clip(
                    ((-map_maas_i) - vmin[2]) / ((vmax[2] - vmin[2])), 0, 1)

                im = plt.imshow(clmap_rgb,
                                extent=extent,
                                aspect='equal',
                                interpolation='nearest',
                                origin='lower',
                                alpha=1.0)

            else:

                # Establish image scaling
                if not args.absscale:
                    ind_use = np.nonzero(sigma > 1e-15)
                    vrange = np.percentile(sigma[ind_use], args.scale)
                else:
                    vrange = args.scale
                print(f'Sigma range: {vrange[0]:.4f} -- {vrange[1]:.4f}')

                # Case B: temperature/diffusion parameter image
                if (args.ptype == 0
                        and args.imtype in ['temp', 'diffusion_parameters']
                        and not args.no_double_image):
                    if args.imtype == 'temp':
                        cmap = None
                    elif args.imtype == 'diffusion_parameters':
                        cmap = cmocean.cm.haline
                    clmap_rgb = ir.make_double_image(
                        sigma,
                        tmap,
                        percSigma=vrange,
                        absSigma=True,
                        rangeQuant=args.quantrange,
                        cmap=cmap)

                    im = plt.imshow(clmap_rgb,
                                    extent=extent,
                                    aspect='equal',
                                    interpolation='nearest',
                                    origin='lower',
                                    alpha=1.0)

                else:
                    # Standard sigma images
                    if args.ptype == 0:
                        if args.imtype == 'hi':
                            cmap = plt.cm.bone
                        elif args.imtype == 'sfr':
                            cmap = plt.cm.magma
                        elif args.imtype == 'diffusion_parameters':
                            cmap = cmocean.cm.haline
                        else:
                            cmap = plt.cm.inferno

                    elif args.ptype == 1:
                        cmap = plt.cm.Greys_r
                    elif args.ptype == 4:
                        cmap = plt.cm.bone

                    if args.no_double_image:
                        plotquant = tmap
                        vmin, vmax = args.quantrange[0], args.quantrange[1]
                    else:
                        plotquant = sigma
                        vmin, vmax = vrange[0], vrange[1]

                    im = plt.imshow(plotquant,
                                    cmap=cmap,
                                    extent=extent,
                                    vmin=vmin,
                                    vmax=vmax,
                                    origin='lower',
                                    interpolation='nearest',
                                    aspect='equal')

        # Plot BHs if desired:
        if show_bhs and bh_mass is not None:

            if args.bh_file is not None:
                bh_inds = np.loadtxt(args.bh_file, dtype=int)
            else:
                bh_inds = np.arange(bh_pos.shape[0])

            ind_show = np.nonzero(
                (np.abs(bh_pos[bh_inds, 0] - camPos[0]) < args.imsize)
                & (np.abs(bh_pos[bh_inds, 1] - camPos[1]) < args.imsize)
                & (np.abs(bh_pos[bh_inds, 2] - camPos[2]) < args.zsize)
                & (bh_ft[bh_inds] >= args.bh_ftrange[0])
                & (bh_ft[bh_inds] <= args.bh_ftrange[1])
                & (bh_mass[bh_inds] >= 10.0**args.bh_mrange[0])
                & (bh_mass[bh_inds] <= 10.0**args.bh_mrange[1]))[0]
            ind_show = bh_inds[ind_show]

            if args.bh_quant == 'mass':
                sorter = np.argsort(bh_mass[ind_show])
                sc = plt.scatter(bh_pos[ind_show[sorter], 0] - camPos[0],
                                 bh_pos[ind_show[sorter], 1] - camPos[1],
                                 marker='o',
                                 c=np.log10(bh_mass[ind_show[sorter]]),
                                 edgecolor='grey',
                                 vmin=5.0,
                                 vmax=args.bh_mmax,
                                 s=5.0,
                                 linewidth=0.2)
                bticks = np.linspace(5.0, args.bh_mmax, num=6, endpoint=True)
                blabel = r'log$_{10}$ ($m_\mathrm{BH}$ [M$_\odot$])'

            elif args.bh_quant == 'formation':
                sorter = np.argsort(bh_ft[ind_show])
                sc = plt.scatter(bh_pos[ind_show[sorter], 0] - camPos[0],
                                 bh_pos[ind_show[sorter], 1] - camPos[1],
                                 marker='o',
                                 c=bh_ft[ind_show[sorter]],
                                 edgecolor='grey',
                                 vmin=0,
                                 vmax=1.0,
                                 s=5.0,
                                 linewidth=0.2)
                bticks = np.linspace(0.0, 1.0, num=6, endpoint=True)
                blabel = 'Formation scale factor'

            if args.bhind:
                for ibh in ind_show[sorter]:
                    c = plt.cm.viridis(
                        (np.log10(bh_mass[ibh]) - 5.0) / (args.bh_mmax - 5.0))
                    plt.text(bh_pos[ibh, 0] - camPos[0] + args.imsize / 200,
                             bh_pos[ibh, 1] - camPos[1] + args.imsize / 200,
                             f'{ibh}',
                             color=c,
                             fontsize=4,
                             va='bottom',
                             ha='left')

            if args.draw_hsml:
                phi = np.arange(0, 2.01 * np.pi, 0.01)
                plt.plot(args.hsml * np.cos(phi),
                         args.hsml * np.sin(phi),
                         color='white',
                         linestyle=':',
                         linewidth=0.5)

            # Add colour bar for BH masses
            if args.imtype != 'sfr':
                ax2 = fig.add_axes([0.6, 0.07, 0.35, 0.02])
                ax2.set_xticks([])
                ax2.set_yticks([])
                cbar = plt.colorbar(sc,
                                    cax=ax2,
                                    orientation='horizontal',
                                    ticks=bticks)
                cbar.ax.tick_params(labelsize=8)
                fig.text(0.775,
                         0.1,
                         blabel,
                         rotation=0.0,
                         va='bottom',
                         ha='center',
                         color='white',
                         fontsize=8)

        # Done with main image, some embellishments...
        plt.sca(ax)
        plt.text(-0.045 / 0.05 * args.imsize,
                 0.045 / 0.05 * args.imsize,
                 'z = {:.3f}'.format(1 / aexp_factor - 1),
                 va='center',
                 ha='left',
                 color='white')
        plt.text(-0.045 / 0.05 * args.imsize,
                 0.041 / 0.05 * args.imsize,
                 't = {:.3f} Gyr'.format(time_gyr),
                 va='center',
                 ha='left',
                 color='white',
                 fontsize=8)

        plot_bar()

        # Plot colorbar for SFR if appropriate
        if args.ptype == 0 and args.imtype == 'sfr':
            ax2 = fig.add_axes([0.6, 0.07, 0.35, 0.02])
            ax2.set_xticks([])
            ax2.set_yticks([])

            scc = plt.scatter([-1e10], [-1e10],
                              c=[0],
                              cmap=plt.cm.magma,
                              vmin=vrange[0],
                              vmax=vrange[1])
            cbar = plt.colorbar(scc,
                                cax=ax2,
                                orientation='horizontal',
                                ticks=np.linspace(np.floor(vrange[0]),
                                                  np.ceil(vrange[1]),
                                                  5,
                                                  endpoint=True))
            cbar.ax.tick_params(labelsize=8)
            fig.text(
                0.775,
                0.1,
                r'log$_{10}$ ($\Sigma_\mathrm{SFR}$ [M$_\odot$ yr$^{-1}$ kpc$^{-2}$])',
                rotation=0.0,
                va='bottom',
                ha='center',
                color='white',
                fontsize=8)

        ax.set_xlabel(r'$\Delta x$ [pMpc]')
        ax.set_ylabel(r'$\Delta y$ [pMpc]')

        ax.set_xlim((-args.imsize, args.imsize))
        ax.set_ylim((-args.imsize, args.imsize))

        plt.savefig(plotloc + str(isnap).zfill(4) + '.png',
                    dpi=args.numpix / args.inch)
        plt.close()

    print(f"Finished snapshot {isnap} in {(time.time() - stime):.2f} sec.")
    print(f"Image saved in {plotloc}{isnap:04d}.png")
def process_sim(isim, args):
    """Copy the black-hole catalogue of one simulation to the output file.

    The simulation directory is located below ``args.basedir``, the
    catalogue file ``args.catloc`` inside it is read, and the data sets
    listed in ``args.bh_fields`` plus a fixed set of metadata (times,
    redshifts, first indices and, if available, halo information) are
    written to ``args.outfile`` under the group ``ID{isim}``.

    Parameters
    ----------
    isim : int or str
        Simulation to process. An integer is resolved by globbing for a
        unique directory ``ID{isim}*`` below ``args.basedir``; a string
        is used directly as the directory name.
    args : object
        Settings object providing the attributes ``basedir``, ``catloc``,
        ``bh_fields`` and ``outfile``.

    Returns
    -------
    None
        The function returns early, after printing a message, if the
        catalogue file does not exist.

    Raises
    ------
    RuntimeError
        If an integer ``isim`` does not match exactly one directory, or
        if one of the data sets in ``args.bh_fields`` cannot be loaded.
    """
    if isinstance(isim, int):
        pattern = f'{args.basedir}/ID{isim}*/'
        wdirs = glob.glob(pattern)
        # Fail loudly instead of dropping into a debugger: batch jobs
        # cannot interact with a breakpoint, and an empty match would
        # otherwise surface as an obscure IndexError below.
        if len(wdirs) != 1:
            raise RuntimeError(
                f"Expected exactly one directory matching '{pattern}', "
                f"but found {len(wdirs)}.")
        wdir = wdirs[0]
    else:
        wdir = args.basedir + '/' + isim + '/'
    print(f"Analysing simulation {wdir}...")

    bfile = wdir + args.catloc
    if not os.path.isfile(bfile):
        print(f"Could not find BH data file for simulation {isim}.")
        return

    # NOTE(review): the output group is 'ID{isim}' for string names too
    # (e.g. 'IDmysim'); kept unchanged to preserve the existing file
    # layout -- confirm this is what downstream readers expect.

    # Copy out the desired fields
    for field in args.bh_fields:
        data = hd.read_data(bfile, field)
        # The check below implies that read_data signals a missing data
        # set by returning None; never pass that on to write_data.
        if data is None:
            raise RuntimeError(
                f"Could not load data set '{field}' from '{bfile}'!")

        hd.write_data(args.outfile, f'ID{isim}/{field}', data)

    # Copy out metadata fields
    times = hd.read_data(bfile, 'Times')
    redshifts = hd.read_data(bfile, 'Redshifts')
    first_indices = hd.read_data(bfile, 'FirstIndices')
    vr_haloes = hd.read_data(bfile, 'Haloes')
    vr_halo_mstar = hd.read_data(bfile, 'Halo_MStar')
    vr_halo_sfr = hd.read_data(bfile, 'Halo_SFR')
    vr_halo_m200c = hd.read_data(bfile, 'Halo_M200c')
    vr_halo_types = hd.read_data(bfile, 'HaloTypes')
    vr_halo_flag = hd.read_data(bfile, 'Flag_MostMassiveInHalo')

    hd.write_data(args.outfile, f'ID{isim}/Times', times)
    hd.write_data(args.outfile, f'ID{isim}/Redshifts', redshifts)
    hd.write_data(args.outfile, f'ID{isim}/FirstIndices', first_indices)

    # Halo information is optional: it is only copied when the catalogue
    # contains the 'Haloes' data set. Note that two entries are renamed
    # on output ('HaloTypes' -> 'Halo_Types' and
    # 'Flag_MostMassiveInHalo' -> 'Halo_FlagMostMassiveBH').
    if vr_haloes is not None:
        hd.write_data(args.outfile, f'ID{isim}/Haloes', vr_haloes)
        hd.write_data(args.outfile, f'ID{isim}/Halo_MStar', vr_halo_mstar)
        hd.write_data(args.outfile, f'ID{isim}/Halo_SFR', vr_halo_sfr)
        hd.write_data(args.outfile, f'ID{isim}/Halo_M200c', vr_halo_m200c)
        hd.write_data(args.outfile, f'ID{isim}/Halo_Types', vr_halo_types)
        hd.write_data(args.outfile, f'ID{isim}/Halo_FlagMostMassiveBH',
                      vr_halo_flag)