Пример #1
0
def main(cmdargs):
    """Compute the cross-correlation between a catalog of objects and a delta
    field.

    Parses the command line, reads the delta field and the object catalog,
    computes the cross-correlation in parallel over healpix pixels, and
    writes the result to the FITS file given by --out ('ATTRI' and 'COR'
    extensions).

    Args:
        cmdargs: List of command-line tokens forwarded to argparse.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description=('Compute the cross-correlation between a catalog of '
                     'objects and a delta field.'))

    parser.add_argument('--out',
                        type=str,
                        default=None,
                        required=True,
                        help='Output file name')

    parser.add_argument('--in-dir',
                        type=str,
                        default=None,
                        required=True,
                        help='Directory to delta files')

    parser.add_argument('--from-image',
                        type=str,
                        default=None,
                        required=False,
                        help='Read delta from image format',
                        nargs='*')

    parser.add_argument('--drq',
                        type=str,
                        default=None,
                        required=True,
                        help='Catalog of objects in format selected by mode')

    parser.add_argument('--mode',
                        type=str,
                        required=False,
                        choices=['desi', 'sdss'],
                        default='sdss',
                        help='Mode for reading the catalog (default sdss)')

    parser.add_argument('--rp-min',
                        type=float,
                        default=-200.,
                        required=False,
                        help='Min r-parallel [h^-1 Mpc]')

    parser.add_argument('--rp-max',
                        type=float,
                        default=200.,
                        required=False,
                        help='Max r-parallel [h^-1 Mpc]')

    parser.add_argument('--rt-max',
                        type=float,
                        default=200.,
                        required=False,
                        help='Max r-transverse [h^-1 Mpc]')

    parser.add_argument('--np',
                        type=int,
                        default=100,
                        required=False,
                        help='Number of r-parallel bins')

    parser.add_argument('--nt',
                        type=int,
                        default=50,
                        required=False,
                        help='Number of r-transverse bins')

    parser.add_argument('--z-min-obj',
                        type=float,
                        default=0,
                        required=False,
                        help='Min redshift for object field')

    parser.add_argument('--z-max-obj',
                        type=float,
                        default=10,
                        required=False,
                        help='Max redshift for object field')

    parser.add_argument(
        '--z-cut-min',
        type=float,
        default=0.,
        required=False,
        help=('Use only pairs of forest x object with the mean of the last '
              'absorber redshift and the object redshift larger than '
              'z-cut-min'))

    parser.add_argument(
        '--z-cut-max',
        type=float,
        default=10.,
        required=False,
        help=('Use only pairs of forest x object with the mean of the last '
              'absorber redshift and the object redshift smaller than '
              'z-cut-max'))

    parser.add_argument(
        '--lambda-abs',
        type=str,
        default='LYA',
        required=False,
        help=(
            'Name of the absorption in picca.constants defining the redshift '
            'of the delta'))

    parser.add_argument('--z-ref',
                        type=float,
                        default=2.25,
                        required=False,
                        help='Reference redshift')

    parser.add_argument(
        '--z-evol-del',
        type=float,
        default=2.9,
        required=False,
        help='Exponent of the redshift evolution of the delta field')

    parser.add_argument(
        '--z-evol-obj',
        type=float,
        default=1.,
        required=False,
        help='Exponent of the redshift evolution of the object field')

    parser.add_argument(
        '--fid-Om',
        type=float,
        default=0.315,
        required=False,
        help='Omega_matter(z=0) of fiducial LambdaCDM cosmology')

    parser.add_argument(
        '--fid-Or',
        type=float,
        default=0.,
        required=False,
        help='Omega_radiation(z=0) of fiducial LambdaCDM cosmology')

    parser.add_argument('--fid-Ok',
                        type=float,
                        default=0.,
                        required=False,
                        help='Omega_k(z=0) of fiducial LambdaCDM cosmology')

    parser.add_argument(
        '--fid-wl',
        type=float,
        default=-1.,
        required=False,
        help='Equation of state of dark energy of fiducial LambdaCDM cosmology'
    )

    parser.add_argument('--no-project',
                        action='store_true',
                        required=False,
                        help='Do not project out continuum fitting modes')

    parser.add_argument('--no-remove-mean-lambda-obs',
                        action='store_true',
                        required=False,
                        help='Do not remove mean delta versus lambda_obs')

    parser.add_argument('--nside',
                        type=int,
                        default=16,
                        required=False,
                        help='Healpix nside')

    parser.add_argument('--nproc',
                        type=int,
                        default=None,
                        required=False,
                        help='Number of processors')

    parser.add_argument('--nspec',
                        type=int,
                        default=None,
                        required=False,
                        help='Maximum number of spectra to read')

    parser.add_argument(
        '--shuffle-distrib-obj-seed',
        type=int,
        default=None,
        required=False,
        help=('Shuffle the distribution of objects on the sky following the '
              'given seed. Do not shuffle if None'))

    parser.add_argument(
        '--shuffle-distrib-forest-seed',
        type=int,
        default=None,
        required=False,
        help=('Shuffle the distribution of forests on the sky following the '
              'given seed. Do not shuffle if None'))

    args = parser.parse_args(cmdargs)
    # default to half of the available cores
    if args.nproc is None:
        args.nproc = cpu_count() // 2

    # setup variables in module xcf
    xcf.r_par_max = args.rp_max
    xcf.r_par_min = args.rp_min
    xcf.z_cut_max = args.z_cut_max
    xcf.z_cut_min = args.z_cut_min
    xcf.r_trans_max = args.rt_max
    xcf.num_bins_r_par = args.np
    xcf.num_bins_r_trans = args.nt
    xcf.nside = args.nside
    xcf.lambda_abs = constants.ABSORBER_IGM[args.lambda_abs]

    # read blinding keyword
    blinding = io.read_blinding(args.in_dir)

    # load fiducial cosmology
    cosmo = constants.Cosmo(Om=args.fid_Om,
                            Or=args.fid_Or,
                            Ok=args.fid_Ok,
                            wl=args.fid_wl,
                            blinding=blinding)

    t0 = time.time()

    ### Read deltas
    # Deltas are read first so that z_min/z_max are defined before the
    # redshift-range fallback below uses them (the previous ordering
    # referenced them before assignment).
    data, num_data, z_min, z_max = io.read_deltas(args.in_dir,
                                                  args.nside,
                                                  xcf.lambda_abs,
                                                  args.z_evol_del,
                                                  args.z_ref,
                                                  cosmo=cosmo,
                                                  max_num_spec=args.nspec,
                                                  no_project=args.no_project,
                                                  from_image=args.from_image)
    xcf.data = data
    xcf.num_data = num_data
    userprint("")
    userprint("done, npix = {}\n".format(len(data)))

    # Derive the object redshift range from the delta redshift range when it
    # was not given explicitly.
    if args.z_min_obj is None:
        r_comov_min = cosmo.get_r_comov(z_min)
        r_comov_min = max(0., r_comov_min + xcf.r_par_min)
        args.z_min_obj = cosmo.distance_to_redshift(r_comov_min)
        userprint("z_min_obj = {}".format(args.z_min_obj), end="")
    if args.z_max_obj is None:
        r_comov_max = cosmo.get_r_comov(z_max)
        r_comov_max = max(0., r_comov_max + xcf.r_par_max)
        args.z_max_obj = cosmo.distance_to_redshift(r_comov_max)
        userprint("z_max_obj = {}".format(args.z_max_obj), end="")

    ### Read objects
    objs, z_min2 = io.read_objects(args.drq,
                                   args.nside,
                                   args.z_min_obj,
                                   args.z_max_obj,
                                   args.z_evol_obj,
                                   args.z_ref,
                                   cosmo,
                                   mode=args.mode)
    xcf.objs = objs

    ### Remove <delta> vs. lambda_obs
    if not args.no_remove_mean_lambda_obs:
        # The stacking bin width is the smallest log-lambda step found in
        # any forest.
        Forest.delta_log_lambda = None
        for healpix in xcf.data:
            for delta in xcf.data[healpix]:
                delta_log_lambda = np.asarray([
                    delta.log_lambda[index] - delta.log_lambda[index - 1]
                    for index in range(1, delta.log_lambda.size)
                ]).min()
                if Forest.delta_log_lambda is None:
                    Forest.delta_log_lambda = delta_log_lambda
                else:
                    Forest.delta_log_lambda = min(delta_log_lambda,
                                                  Forest.delta_log_lambda)
        Forest.log_lambda_min = (np.log10(
            (z_min + 1.) * xcf.lambda_abs) - Forest.delta_log_lambda / 2.)
        Forest.log_lambda_max = (np.log10(
            (z_max + 1.) * xcf.lambda_abs) + Forest.delta_log_lambda / 2.)
        # Stack the deltas and subtract the stacked mean from each delta.
        log_lambda, mean_delta, stack_weight = prep_del.stack(
            xcf.data, stack_from_deltas=True)
        del log_lambda, stack_weight
        for healpix in xcf.data:
            for delta in xcf.data[healpix]:
                bins = ((delta.log_lambda - Forest.log_lambda_min) /
                        Forest.delta_log_lambda + 0.5).astype(int)
                delta.delta -= mean_delta[bins]

    # shuffle forests and objects
    if args.shuffle_distrib_obj_seed is not None:
        xcf.objs = utils.shuffle_distrib_forests(objs,
                                                 args.shuffle_distrib_obj_seed)
    if args.shuffle_distrib_forest_seed is not None:
        xcf.data = utils.shuffle_distrib_forests(
            xcf.data, args.shuffle_distrib_forest_seed)

    userprint("")

    # compute maximum angular separation
    xcf.ang_max = utils.compute_ang_max(cosmo, xcf.r_trans_max, z_min, z_min2)

    t1 = time.time()
    userprint(f'picca_xcf.py - Time reading data: {(t1-t0)/60:.3f} minutes')

    # compute correlation function, use pool to parallelize
    xcf.counter = Value('i', 0)
    xcf.lock = Lock()
    cpu_data = {healpix: [healpix] for healpix in data}
    context = multiprocessing.get_context('fork')
    pool = context.Pool(processes=args.nproc)
    correlation_function_data = pool.map(corr_func, sorted(cpu_data.values()))
    pool.close()

    t2 = time.time()
    userprint(
        f'picca_xcf.py - Time computing cross-correlation function: {(t2-t1)/60:.3f} minutes'
    )

    # group data from parallelisation
    correlation_function_data = np.array(correlation_function_data)
    weights_list = correlation_function_data[:, 0, :]
    xi_list = correlation_function_data[:, 1, :]
    r_par_list = correlation_function_data[:, 2, :]
    r_trans_list = correlation_function_data[:, 3, :]
    z_list = correlation_function_data[:, 4, :]
    num_pairs_list = correlation_function_data[:, 5, :].astype(np.int64)
    healpix_list = np.array(sorted(list(cpu_data.keys())))

    # Weighted average over healpix pixels; normalize only bins that
    # received any weight.
    sum_weights = weights_list.sum(axis=0)
    w = (sum_weights > 0.)
    r_par = (r_par_list * weights_list).sum(axis=0)
    r_par[w] /= sum_weights[w]
    r_trans = (r_trans_list * weights_list).sum(axis=0)
    r_trans[w] /= sum_weights[w]
    z = (z_list * weights_list).sum(axis=0)
    z[w] /= sum_weights[w]
    num_pairs = num_pairs_list.sum(axis=0)

    # Write the grid attributes and the per-healpix correlation to the
    # output FITS file.
    results = fitsio.FITS(args.out, 'rw', clobber=True)
    header = [{
        'name': 'RPMIN',
        'value': xcf.r_par_min,
        'comment': 'Minimum r-parallel [h^-1 Mpc]'
    }, {
        'name': 'RPMAX',
        'value': xcf.r_par_max,
        'comment': 'Maximum r-parallel [h^-1 Mpc]'
    }, {
        'name': 'RTMAX',
        'value': xcf.r_trans_max,
        'comment': 'Maximum r-transverse [h^-1 Mpc]'
    }, {
        'name': 'NP',
        'value': xcf.num_bins_r_par,
        'comment': 'Number of bins in r-parallel'
    }, {
        'name': 'NT',
        'value': xcf.num_bins_r_trans,
        'comment': 'Number of bins in r-transverse'
    }, {
        'name': 'ZCUTMIN',
        'value': xcf.z_cut_min,
        'comment': 'Minimum redshift of pairs'
    }, {
        'name': 'ZCUTMAX',
        'value': xcf.z_cut_max,
        'comment': 'Maximum redshift of pairs'
    }, {
        'name': 'NSIDE',
        'value': xcf.nside,
        'comment': 'Healpix nside'
    }, {
        'name': 'OMEGAM',
        'value': args.fid_Om,
        'comment': 'Omega_matter(z=0) of fiducial LambdaCDM cosmology'
    }, {
        'name': 'OMEGAR',
        'value': args.fid_Or,
        'comment': 'Omega_radiation(z=0) of fiducial LambdaCDM cosmology'
    }, {
        'name': 'OMEGAK',
        'value': args.fid_Ok,
        'comment': 'Omega_k(z=0) of fiducial LambdaCDM cosmology'
    }, {
        'name': 'WL',
        'value': args.fid_wl,
        'comment':
            'Equation of state of dark energy of fiducial LambdaCDM cosmology'
    }, {
        'name': "BLINDING",
        'value': blinding,
        'comment': 'String specifying the blinding strategy'
    }]
    results.write(
        [r_par, r_trans, z, num_pairs],
        names=['RP', 'RT', 'Z', 'NB'],
        comment=['R-parallel', 'R-transverse', 'Redshift', 'Number of pairs'],
        units=['h^-1 Mpc', 'h^-1 Mpc', '', ''],
        header=header,
        extname='ATTRI')

    header2 = [{
        'name': 'HLPXSCHM',
        'value': 'RING',
        'comment': 'Healpix scheme'
    }]
    # When the data are blinded, rename the correlation column accordingly.
    da_name = "DA"
    if blinding != "none":
        da_name += "_BLIND"
    results.write([healpix_list, weights_list, xi_list],
                  names=['HEALPID', 'WE', da_name],
                  comment=['Healpix index', 'Sum of weight', 'Correlation'],
                  header=header2,
                  extname='COR')

    results.close()

    t3 = time.time()
    userprint(f'picca_xcf.py - Time total: {(t3-t0)/60:.3f} minutes')
Пример #2
0
def main(cmdargs):
    """Compute the 1D cross-correlation between a catalog of objects and a delta
    field as a function of wavelength ratio.

    Parses the command line, reads the deltas and the object catalog (no
    cosmology is needed: the correlation is binned in wavelength ratio), and
    writes the result to the FITS file given by --out ('ATTRI' and 'COR'
    extensions).

    Args:
        cmdargs: List of command-line tokens forwarded to argparse.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description=('Compute the 1D cross-correlation between a catalog of '
                     'objects and a delta field as a function of wavelength '
                     'ratio'))

    parser.add_argument('--out',
                        type=str,
                        default=None,
                        required=True,
                        help='Output file name')

    parser.add_argument('--in-dir',
                        type=str,
                        default=None,
                        required=True,
                        help='Directory to delta files')

    parser.add_argument('--drq',
                        type=str,
                        default=None,
                        required=True,
                        help='Catalog of objects in DRQ format')

    parser.add_argument(
                        '--mode',
                        type=str,
                        default='sdss',
                        choices=['sdss','desi'],
                        required=False,
                        help='type of catalog supplied, default sdss')

    parser.add_argument('--wr-min',
                        type=float,
                        default=0.9,
                        required=False,
                        help='Min of wavelength ratio')

    parser.add_argument('--wr-max',
                        type=float,
                        default=1.1,
                        required=False,
                        help='Max of wavelength ratio')

    parser.add_argument('--np',
                        type=int,
                        default=100,
                        required=False,
                        help='Number of wavelength ratio bins')

    parser.add_argument('--z-min-obj',
                        type=float,
                        default=0,
                        required=False,
                        help='Min redshift for object field')

    parser.add_argument('--z-max-obj',
                        type=float,
                        default=10,
                        required=False,
                        help='Max redshift for object field')

    parser.add_argument(
        '--z-cut-min',
        type=float,
        default=0.,
        required=False,
        help=('Use only pairs of forest x object with the mean of the last '
              'absorber redshift and the object redshift larger than '
              'z-cut-min'))

    parser.add_argument(
        '--z-cut-max',
        type=float,
        default=10.,
        required=False,
        help=('Use only pairs of forest x object with the mean of the last '
              'absorber redshift and the object redshift smaller than '
              'z-cut-max'))

    parser.add_argument(
        '--lambda-abs',
        type=str,
        default='LYA',
        required=False,
        help=('Name of the absorption in picca.constants defining the redshift '
              'of the delta'))

    parser.add_argument(
        '--lambda-abs-obj',
        type=str,
        default='LYA',
        required=False,
        help=('Name of the absorption in picca.constants the object is '
              'considered as'))

    parser.add_argument('--z-ref',
                        type=float,
                        default=2.25,
                        required=False,
                        help='Reference redshift')

    parser.add_argument(
        '--z-evol-del',
        type=float,
        default=2.9,
        required=False,
        help='Exponent of the redshift evolution of the delta field')

    parser.add_argument(
        '--z-evol-obj',
        type=float,
        default=1.,
        required=False,
        help='Exponent of the redshift evolution of the object field')

    parser.add_argument('--no-project',
                        action='store_true',
                        required=False,
                        help='Do not project out continuum fitting modes')

    parser.add_argument('--no-remove-mean-lambda-obs',
                        action='store_true',
                        required=False,
                        help='Do not remove mean delta versus lambda_obs')

    parser.add_argument('--nside',
                        type=int,
                        default=16,
                        required=False,
                        help='Healpix nside')

    parser.add_argument('--nproc',
                        type=int,
                        default=None,
                        required=False,
                        help='Number of processors')

    parser.add_argument('--nspec',
                        type=int,
                        default=None,
                        required=False,
                        help='Maximum number of spectra to read')

    args = parser.parse_args(cmdargs)
    # default to half of the available cores
    if args.nproc is None:
        args.nproc = cpu_count() // 2

    # setup variables in module xcf: the "r-parallel" axis is reused to hold
    # the wavelength ratio, and the transverse direction is collapsed to a
    # single (effectively zero-width) bin for this 1D correlation
    xcf.r_par_min = args.wr_min
    xcf.r_par_max = args.wr_max
    xcf.r_trans_max = 1.e-6
    xcf.z_cut_min = args.z_cut_min
    xcf.z_cut_max = args.z_cut_max
    xcf.num_bins_r_par = args.np
    # consistent with the other xcf scripts, which set
    # xcf.num_bins_r_trans (the old 'nt' attribute name is stale)
    xcf.num_bins_r_trans = 1
    xcf.nside = args.nside
    xcf.ang_correlation = True

    lambda_abs = constants.ABSORBER_IGM[args.lambda_abs]

    ### Read deltas
    # cosmo=None: no comoving coordinates are needed in wavelength-ratio
    # binning
    data, num_data, z_min, z_max = io.read_deltas(args.in_dir,
                                                  args.nside,
                                                  lambda_abs,
                                                  args.z_evol_del,
                                                  args.z_ref,
                                                  cosmo=None,
                                                  max_num_spec=args.nspec,
                                                  no_project=args.no_project)
    xcf.data = data
    xcf.num_data = num_data
    sys.stderr.write("\n")
    userprint("done, npix = {}".format(len(data)))

    ### Remove <delta> vs. lambda_obs
    if not args.no_remove_mean_lambda_obs:
        # The stacking bin width is the smallest log-lambda step found in
        # any forest.
        Forest.delta_log_lambda = None
        for healpix in xcf.data:
            for delta in xcf.data[healpix]:
                delta_log_lambda = np.asarray([
                    delta.log_lambda[index] - delta.log_lambda[index - 1]
                    for index in range(1, delta.log_lambda.size)
                ]).min()
                if Forest.delta_log_lambda is None:
                    Forest.delta_log_lambda = delta_log_lambda
                else:
                    Forest.delta_log_lambda = min(delta_log_lambda,
                                                  Forest.delta_log_lambda)
        Forest.log_lambda_min = (np.log10(
            (z_min + 1.) * lambda_abs) - Forest.delta_log_lambda / 2.)
        Forest.log_lambda_max = (np.log10(
            (z_max + 1.) * lambda_abs) + Forest.delta_log_lambda / 2.)
        # Stack the deltas and subtract the stacked mean from each delta.
        log_lambda, mean_delta, stack_weight = prep_del.stack(
            xcf.data, stack_from_deltas=True)
        del log_lambda, stack_weight
        for healpix in xcf.data:
            for delta in xcf.data[healpix]:
                bins = ((delta.log_lambda - Forest.log_lambda_min) /
                        Forest.delta_log_lambda + 0.5).astype(int)
                delta.delta -= mean_delta[bins]

    ### Read objects
    objs, z_min2 = io.read_objects(args.drq,
                                   args.nside,
                                   args.z_min_obj,
                                   args.z_max_obj,
                                   args.z_evol_obj,
                                   args.z_ref,
                                   cosmo=None,
                                   mode=args.mode)
    # the minimum object redshift is not needed in the 1D case
    del z_min2
    xcf.objs = objs
    # Assign each object the log-wavelength at which its chosen absorption
    # line is observed.
    for healpix in xcf.objs:
        for obj in xcf.objs[healpix]:
            obj.log_lambda = np.log10(
                (1. + obj.z_qso) * constants.ABSORBER_IGM[args.lambda_abs_obj])
    sys.stderr.write("\n")

    # Compute the correlation function, use pool to parallelize
    context = multiprocessing.get_context('fork')
    pool = context.Pool(processes=args.nproc)
    # only healpix pixels that contain both deltas and objects contribute
    healpixs = [[healpix] for healpix in sorted(data) if healpix in xcf.objs]
    correlation_function_data = pool.map(corr_func, healpixs)
    pool.close()

    # group data from parallelisation
    correlation_function_data = np.array(correlation_function_data)
    weights_list = correlation_function_data[:, 0, :]
    xi_list = correlation_function_data[:, 1, :]
    r_par_list = correlation_function_data[:, 2, :]
    z_list = correlation_function_data[:, 3, :]
    num_pairs_list = correlation_function_data[:, 4, :].astype(np.int64)
    healpix_list = np.array(
        [healpix for healpix in sorted(data) if healpix in xcf.objs])

    # Weighted average over healpix pixels; normalize only bins that
    # received any weight.
    w = (weights_list.sum(axis=0) > 0.)
    r_par = (r_par_list * weights_list).sum(axis=0)
    r_par[w] /= weights_list.sum(axis=0)[w]
    z = (z_list * weights_list).sum(axis=0)
    z[w] /= weights_list.sum(axis=0)[w]
    num_pairs = num_pairs_list.sum(axis=0)

    # Write the grid attributes and the per-healpix correlation to the
    # output FITS file.
    results = fitsio.FITS(args.out, 'rw', clobber=True)
    header = [{
        'name': 'RPMIN',
        'value': xcf.r_par_min,
        'comment': 'Minimum wavelength ratio'
    }, {
        'name': 'RPMAX',
        'value': xcf.r_par_max,
        'comment': 'Maximum wavelength ratio'
    }, {
        'name': 'NP',
        'value': xcf.num_bins_r_par,
        'comment': 'Number of bins in wavelength ratio'
    }, {
        'name': 'ZCUTMIN',
        'value': xcf.z_cut_min,
        'comment': 'Minimum redshift of pairs'
    }, {
        'name': 'ZCUTMAX',
        'value': xcf.z_cut_max,
        'comment': 'Maximum redshift of pairs'
    }, {
        'name': 'NSIDE',
        'value': xcf.nside,
        'comment': 'Healpix nside'
    }]
    results.write([r_par, z, num_pairs],
                  names=['RP', 'Z', 'NB'],
                  units=['', '', ''],
                  comment=['Wavelength ratio', 'Redshift', 'Number of pairs'],
                  header=header,
                  extname='ATTRI')

    header2 = [{
        'name': 'HLPXSCHM',
        'value': 'RING',
        'comment': 'Healpix scheme'
    }]
    results.write([healpix_list, weights_list, xi_list],
                  names=['HEALPID', 'WE', 'DA'],
                  comment=['Healpix index', 'Sum of weight', 'Correlation'],
                  header=header2,
                  extname='COR')

    results.close()
Пример #3
0
    # Remove the mean delta as a function of observed wavelength, unless
    # disabled on the command line.
    if not args.no_remove_mean_lambda_obs:
        forest.dll = None
        # The stacking bin width is the smallest log-lambda step found in
        # any forest.
        for p in xcf.dels:
            for d in xcf.dels[p]:
                dll = sp.asarray([
                    d.ll[ii] - d.ll[ii - 1] for ii in range(1, d.ll.size)
                ]).min()
                if forest.dll is None:
                    forest.dll = dll
                else:
                    forest.dll = min(dll, forest.dll)
        # Wavelength-grid limits from the pixel redshift range
        # (zmin_pix / zmax_pix are defined earlier, outside this excerpt).
        forest.lmin = sp.log10(
            (zmin_pix + 1.) * xcf.lambda_abs) - forest.dll / 2.
        forest.lmax = sp.log10(
            (zmax_pix + 1.) * xcf.lambda_abs) + forest.dll / 2.
        # Stack the deltas and subtract the stacked mean from each delta.
        ll, st, wst = prep_del.stack(xcf.dels, delta=True)
        for p in xcf.dels:
            for d in xcf.dels[p]:
                bins = ((d.ll - forest.lmin) / forest.dll + 0.5).astype(int)
                d.de -= st[bins]

    ### Read objects
    objs,zmin_obj = io.read_objects(args.drq, args.nside, args.z_min_obj, args.z_max_obj,\
                                args.z_evol_obj, args.z_ref,cosmo)
    # Assign each object the log-wavelength at which its chosen absorption
    # line is observed.
    for i, ipix in enumerate(sorted(objs.keys())):
        for q in objs[ipix]:
            q.ll = sp.log10(
                (1. + q.zqso) * constants.absorber_IGM[args.lambda_abs_obj])
    print("")
    xcf.objs = objs
Пример #4
0
                # NOTE(review): fragment starts mid-loop; nlss, ll, eta,
                # vlss and fudge are defined outside this excerpt.
                nqsos = np.zeros((nlss, nlss))

                # Nearest-neighbour interpolators for the weight terms,
                # extrapolated beyond the sampled log-lambda grid.
                forest.eta = interp1d(ll,
                                      eta,
                                      fill_value='extrapolate',
                                      kind='nearest')
                forest.var_lss = interp1d(ll,
                                          vlss,
                                          fill_value='extrapolate',
                                          kind='nearest')
                forest.fudge = interp1d(ll,
                                        fudge,
                                        fill_value='extrapolate',
                                        kind='nearest')

    # Stack all deltas to get the mean delta versus log-lambda.
    ll_st, st, wst = prep_del.stack(data)

    ### Save iter_out_prefix
    res = fitsio.FITS(args.iter_out_prefix + ".fits.gz", 'rw', clobber=True)
    hd = {}
    hd["NSIDE"] = healpy_nside
    hd["PIXORDER"] = healpy_pix_ordering
    hd["FITORDER"] = args.order
    res.write([ll_st, st, wst],
              names=['loglam', 'stack', 'weight'],
              header=hd,
              extname='STACK')
    res.write([ll, eta, vlss, fudge, nb_pixels],
              names=['loglam', 'eta', 'var_lss', 'fudge', 'nb_pixels'],
              extname='WEIGHT')
    # NOTE(review): the following statement is truncated in this excerpt.
    res.write([ll_rest, forest.mean_cont(ll_rest), wmc],
Пример #5
0
def main():
    """Computes the cross-correlation between a catalog of objects and a delta
    field as a function of angle and wavelength ratio.

    Workflow:
      1. Parse command-line options.
      2. Configure the module-level state of ``xcf`` in angular-correlation
         mode (``ang_correlation = True``): the "r_par" axis holds the
         wavelength ratio and the "r_trans" axis holds the angle in radians.
      3. Read the delta field, optionally subtract the mean delta as a
         function of observed wavelength, then read the object catalog.
      4. Compute the correlation in parallel over healpix pixels and write
         the per-pixel and averaged results to a FITS file (``--out``).
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description=('Compute the cross-correlation between a catalog of '
                     'objects and a delta field as a function of angle and '
                     'wavelength ratio'))

    parser.add_argument('--out',
                        type=str,
                        default=None,
                        required=True,
                        help='Output file name')

    parser.add_argument('--in-dir',
                        type=str,
                        default=None,
                        required=True,
                        help='Directory to delta files')

    parser.add_argument('--drq',
                        type=str,
                        default=None,
                        required=True,
                        help='Catalog of objects in DRQ format')

    parser.add_argument('--wr-min',
                        type=float,
                        default=0.9,
                        required=False,
                        help='Min of wavelength ratio')

    parser.add_argument('--wr-max',
                        type=float,
                        default=1.1,
                        required=False,
                        help='Max of wavelength ratio')

    parser.add_argument('--ang-max',
                        type=float,
                        default=0.02,
                        required=False,
                        help='Max angle (rad)')

    parser.add_argument('--np',
                        type=int,
                        default=100,
                        required=False,
                        help='Number of wavelength ratio bins')

    parser.add_argument('--nt',
                        type=int,
                        default=50,
                        required=False,
                        help='Number of angular bins')

    parser.add_argument('--z-min-obj',
                        type=float,
                        default=None,
                        required=False,
                        help='Min redshift for object field')

    parser.add_argument('--z-max-obj',
                        type=float,
                        default=None,
                        required=False,
                        help='Max redshift for object field')

    parser.add_argument(
        '--z-cut-min',
        type=float,
        default=0.,
        required=False,
        help=('Use only pairs of forest x object with the mean of the last '
              'absorber redshift and the object redshift larger than '
              'z-cut-min'))

    parser.add_argument(
        '--z-cut-max',
        type=float,
        default=10.,
        required=False,
        help=('Use only pairs of forest x object with the mean of the last '
              'absorber redshift and the object redshift smaller than '
              'z-cut-max'))

    parser.add_argument(
        '--lambda-abs',
        type=str,
        default='LYA',
        required=False,
        help=('Name of the absorption in picca.constants defining the redshift '
              'of the delta'))

    parser.add_argument(
        '--lambda-abs-obj',
        type=str,
        default='LYA',
        required=False,
        help=('Name of the absorption in picca.constants the object is '
              'considered as'))

    parser.add_argument('--z-ref',
                        type=float,
                        default=2.25,
                        required=False,
                        help='Reference redshift')

    parser.add_argument(
        '--z-evol-del',
        type=float,
        default=2.9,
        required=False,
        help='Exponent of the redshift evolution of the delta field')

    parser.add_argument(
        '--z-evol-obj',
        type=float,
        default=1.,
        required=False,
        help='Exponent of the redshift evolution of the object field')

    parser.add_argument(
        '--fid-Om',
        type=float,
        default=0.315,
        required=False,
        help='Omega_matter(z=0) of fiducial LambdaCDM cosmology')

    parser.add_argument(
        '--fid-Or',
        type=float,
        default=0.,
        required=False,
        help='Omega_radiation(z=0) of fiducial LambdaCDM cosmology')

    parser.add_argument('--fid-Ok',
                        type=float,
                        default=0.,
                        required=False,
                        help='Omega_k(z=0) of fiducial LambdaCDM cosmology')

    parser.add_argument(
        '--fid-wl',
        type=float,
        default=-1.,
        required=False,
        help='Equation of state of dark energy of fiducial LambdaCDM cosmology')

    parser.add_argument('--no-project',
                        action='store_true',
                        required=False,
                        help='Do not project out continuum fitting modes')

    parser.add_argument('--no-remove-mean-lambda-obs',
                        action='store_true',
                        required=False,
                        help='Do not remove mean delta versus lambda_obs')

    parser.add_argument('--nside',
                        type=int,
                        default=16,
                        required=False,
                        help='Healpix nside')

    parser.add_argument('--nproc',
                        type=int,
                        default=None,
                        required=False,
                        help='Number of processors')

    parser.add_argument('--nspec',
                        type=int,
                        default=None,
                        required=False,
                        help='Maximum number of spectra to read')

    args = parser.parse_args()

    # Default to half the available cores when --nproc is not given.
    if args.nproc is None:
        args.nproc = cpu_count() // 2

    # setup variables in module xcf
    # NOTE(review): in angular-correlation mode the generic r_par/r_trans
    # slots are reused — r_par holds the wavelength ratio and r_trans the
    # angle in radians (consistent with the FITS header comments below).
    xcf.r_par_min = args.wr_min
    xcf.r_par_max = args.wr_max
    xcf.r_trans_max = args.ang_max
    xcf.z_cut_min = args.z_cut_min
    xcf.z_cut_max = args.z_cut_max
    xcf.num_bins_r_par = args.np
    xcf.num_bins_r_trans = args.nt
    xcf.nside = args.nside
    xcf.ang_correlation = True
    # Also stored under its dedicated name, in addition to r_trans_max above.
    xcf.ang_max = args.ang_max
    xcf.lambda_abs = constants.ABSORBER_IGM[args.lambda_abs]

    # load fiducial cosmology
    cosmo = constants.Cosmo(Om=args.fid_Om,
                            Or=args.fid_Or,
                            Ok=args.fid_Ok,
                            wl=args.fid_wl)

    ### Read deltas
    # data: dict healpix -> list of deltas; num_data: number of deltas read.
    data, num_data, z_min, z_max = io.read_deltas(
        args.in_dir,
        args.nside,
        constants.ABSORBER_IGM[args.lambda_abs],
        args.z_evol_del,
        args.z_ref,
        cosmo=cosmo,
        max_num_spec=args.nspec,
        no_project=args.no_project)
    xcf.data = data
    xcf.num_data = num_data
    userprint("")
    userprint("done, npix = {}".format(len(data)))

    ### Remove <delta> vs. lambda_obs
    # Stack all deltas on a common observed-wavelength grid and subtract the
    # stacked mean from each delta, unless explicitly disabled.
    if not args.no_remove_mean_lambda_obs:
        Forest.delta_log_lambda = None
        # Grid spacing = smallest pixel spacing found over all deltas.
        for healpix in xcf.data:
            for delta in xcf.data[healpix]:
                delta_log_lambda = np.asarray([
                    delta.log_lambda[index] - delta.log_lambda[index - 1]
                    for index in range(1, delta.log_lambda.size)
                ]).min()
                if Forest.delta_log_lambda is None:
                    Forest.delta_log_lambda = delta_log_lambda
                else:
                    Forest.delta_log_lambda = min(delta_log_lambda,
                                                  Forest.delta_log_lambda)
        # Grid limits from the delta redshift range, padded by half a pixel.
        Forest.log_lambda_min = (np.log10(
            (z_min + 1.) * xcf.lambda_abs) - Forest.delta_log_lambda / 2.)
        Forest.log_lambda_max = (np.log10(
            (z_max + 1.) * xcf.lambda_abs) + Forest.delta_log_lambda / 2.)
        log_lambda, mean_delta, stack_weight = prep_del.stack(
            xcf.data, stack_from_deltas=True)
        # Only the stacked mean is needed below.
        del log_lambda, stack_weight
        for healpix in xcf.data:
            for delta in xcf.data[healpix]:
                # Map each pixel's log-lambda to its bin on the stack grid
                # (the +0.5 rounds to the nearest bin).
                bins = ((delta.log_lambda - Forest.log_lambda_min) /
                        Forest.delta_log_lambda + 0.5).astype(int)
                delta.delta -= mean_delta[bins]

    ### Read objects
    objs, z_min2 = io.read_objects(args.drq, args.nside, args.z_min_obj,
                                   args.z_max_obj, args.z_evol_obj, args.z_ref,
                                   cosmo)
    # The minimum object redshift is not needed in angular mode.
    del z_min2
    # Assign each object the log observed wavelength of the chosen absorber
    # at the object's redshift.
    for index, healpix in enumerate(sorted(objs)):
        for obj in objs[healpix]:
            obj.log_lambda = np.log10(
                (1. + obj.z_qso) * constants.ABSORBER_IGM[args.lambda_abs_obj])
    userprint("")
    xcf.objs = objs

    # compute correlation function, use pool to parallelize
    # Shared counter/lock let the workers report progress safely.
    xcf.counter = Value('i', 0)
    xcf.lock = Lock()
    # One single-pixel work item per healpix, processed in sorted order so
    # the output below lines up with sorted(cpu_data.keys()).
    cpu_data = {healpix: [healpix] for healpix in data}
    pool = Pool(processes=args.nproc)
    correlation_function_data = pool.map(corr_func,
                                         sorted(list(cpu_data.values())))
    pool.close()

    # group data from parallelisation
    # Result array has shape (num_healpix, 6, num_bins); axis 1 indexes the
    # quantities unpacked below.
    correlation_function_data = np.array(correlation_function_data)
    weights_list = correlation_function_data[:, 0, :]
    xi_list = correlation_function_data[:, 1, :]
    r_par_list = correlation_function_data[:, 2, :]
    r_trans_list = correlation_function_data[:, 3, :]
    z_list = correlation_function_data[:, 4, :]
    num_pairs_list = correlation_function_data[:, 5, :].astype(np.int64)
    healpix_list = np.array(sorted(list(cpu_data.keys())))

    # Weighted average over healpix pixels, normalised only where the total
    # weight is positive (empty bins are left as the raw weighted sums).
    w = (weights_list.sum(axis=0) > 0.)
    r_par = (r_par_list * weights_list).sum(axis=0)
    r_par[w] /= weights_list.sum(axis=0)[w]
    r_trans = (r_trans_list * weights_list).sum(axis=0)
    r_trans[w] /= weights_list.sum(axis=0)[w]
    z = (z_list * weights_list).sum(axis=0)
    z[w] /= weights_list.sum(axis=0)[w]
    num_pairs = num_pairs_list.sum(axis=0)

    # save results
    results = fitsio.FITS(args.out, 'rw', clobber=True)
    header = [{
        'name': 'RPMIN',
        'value': xcf.r_par_min,
        'comment': 'Minimum wavelength ratio'
    }, {
        'name': 'RPMAX',
        'value': xcf.r_par_max,
        'comment': 'Maximum wavelength ratio'
    }, {
        'name': 'RTMAX',
        'value': xcf.r_trans_max,
        'comment': 'Maximum angle [rad]'
    }, {
        'name': 'NP',
        'value': xcf.num_bins_r_par,
        'comment': 'Number of bins in wavelength ratio'
    }, {
        'name': 'NT',
        'value': xcf.num_bins_r_trans,
        'comment': 'Number of bins in angle'
    }, {
        'name': 'ZCUTMIN',
        'value': xcf.z_cut_min,
        'comment': 'Minimum redshift of pairs'
    }, {
        'name': 'ZCUTMAX',
        'value': xcf.z_cut_max,
        'comment': 'Maximum redshift of pairs'
    }, {
        'name': 'NSIDE',
        'value': xcf.nside,
        'comment': 'Healpix nside'
    }]
    # HDU 'ATTRI': bin attributes averaged over all healpix pixels.
    results.write(
        [r_par, r_trans, z, num_pairs],
        names=['RP', 'RT', 'Z', 'NB'],
        units=['', 'rad', '', ''],
        comment=['Wavelength ratio', 'Angle', 'Redshift', 'Number of pairs'],
        header=header,
        extname='ATTRI')

    header2 = [{
        'name': 'HLPXSCHM',
        'value': 'RING',
        'comment': ' Healpix scheme'
    }]
    # HDU 'COR': per-healpix weights and correlation values.
    results.write([healpix_list, weights_list, xi_list],
                  names=['HEALPID', 'WE', 'DA'],
                  comment=['Healpix index', 'Sum of weight', 'Correlation'],
                  header=header2,
                  extname='COR')

    results.close()
Пример #6
0
def main():
    # pylint: disable-msg=too-many-locals,too-many-branches,too-many-statements
    """Computes delta field"""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description=('Compute the delta field '
                     'from a list of spectra'))

    parser.add_argument('--out-dir',
                        type=str,
                        default=None,
                        required=True,
                        help='Output directory')

    parser.add_argument('--drq',
                        type=str,
                        default=None,
                        required=True,
                        help='Catalog of objects in DRQ format')

    parser.add_argument('--in-dir',
                        type=str,
                        default=None,
                        required=True,
                        help='Directory to spectra files')

    parser.add_argument('--log',
                        type=str,
                        default='input.log',
                        required=False,
                        help='Log input data')

    parser.add_argument('--iter-out-prefix',
                        type=str,
                        default='iter',
                        required=False,
                        help='Prefix of the iteration file')

    parser.add_argument('--mode',
                        type=str,
                        default='pix',
                        required=False,
                        help=('Open mode of the spectra files: pix, spec, '
                              'spcframe, spplate, desi'))

    parser.add_argument('--best-obs',
                        action='store_true',
                        required=False,
                        help=('If mode == spcframe, then use only the best '
                              'observation'))

    parser.add_argument(
        '--single-exp',
        action='store_true',
        required=False,
        help=('If mode == spcframe, then use only one of the '
              'available exposures. If best-obs then choose it '
              'among those contributing to the best obs'))

    parser.add_argument('--zqso-min',
                        type=float,
                        default=None,
                        required=False,
                        help='Lower limit on quasar redshift from drq')

    parser.add_argument('--zqso-max',
                        type=float,
                        default=None,
                        required=False,
                        help='Upper limit on quasar redshift from drq')

    parser.add_argument('--keep-bal',
                        action='store_true',
                        required=False,
                        help='Do not reject BALs in drq')

    parser.add_argument('--bi-max',
                        type=float,
                        required=False,
                        default=None,
                        help=('Maximum CIV balnicity index in drq (overrides '
                              '--keep-bal)'))

    parser.add_argument('--lambda-min',
                        type=float,
                        default=3600.,
                        required=False,
                        help='Lower limit on observed wavelength [Angstrom]')

    parser.add_argument('--lambda-max',
                        type=float,
                        default=5500.,
                        required=False,
                        help='Upper limit on observed wavelength [Angstrom]')

    parser.add_argument('--lambda-rest-min',
                        type=float,
                        default=1040.,
                        required=False,
                        help='Lower limit on rest frame wavelength [Angstrom]')

    parser.add_argument('--lambda-rest-max',
                        type=float,
                        default=1200.,
                        required=False,
                        help='Upper limit on rest frame wavelength [Angstrom]')

    parser.add_argument('--rebin',
                        type=int,
                        default=3,
                        required=False,
                        help=('Rebin wavelength grid by combining this number '
                              'of adjacent pixels (ivar weight)'))

    parser.add_argument('--npix-min',
                        type=int,
                        default=50,
                        required=False,
                        help='Minimum of rebined pixels')

    parser.add_argument('--dla-vac',
                        type=str,
                        default=None,
                        required=False,
                        help='DLA catalog file')

    parser.add_argument('--dla-mask',
                        type=float,
                        default=0.8,
                        required=False,
                        help=('Lower limit on the DLA transmission. '
                              'Transmissions below this number are masked'))

    parser.add_argument('--absorber-vac',
                        type=str,
                        default=None,
                        required=False,
                        help='Absorber catalog file')

    parser.add_argument(
        '--absorber-mask',
        type=float,
        default=2.5,
        required=False,
        help=('Mask width on each side of the absorber central '
              'observed wavelength in units of '
              '1e4*dlog10(lambda)'))

    parser.add_argument('--mask-file',
                        type=str,
                        default=None,
                        required=False,
                        help=('Path to file to mask regions in lambda_OBS and '
                              'lambda_RF. In file each line is: region_name '
                              'region_min region_max (OBS or RF) [Angstrom]'))

    parser.add_argument('--optical-depth',
                        type=str,
                        default=None,
                        required=False,
                        nargs='*',
                        help=('Correct for the optical depth: tau_1 gamma_1 '
                              'absorber_1 tau_2 gamma_2 absorber_2 ...'))

    parser.add_argument('--dust-map',
                        type=str,
                        default=None,
                        required=False,
                        help=('Path to DRQ catalog of objects for dust map to '
                              'apply the Schlegel correction'))

    parser.add_argument(
        '--flux-calib',
        type=str,
        default=None,
        required=False,
        help=('Path to previously produced picca_delta.py file '
              'to correct for multiplicative errors in the '
              'pipeline flux calibration'))

    parser.add_argument(
        '--ivar-calib',
        type=str,
        default=None,
        required=False,
        help=('Path to previously produced picca_delta.py file '
              'to correct for multiplicative errors in the '
              'pipeline inverse variance calibration'))

    parser.add_argument('--eta-min',
                        type=float,
                        default=0.5,
                        required=False,
                        help='Lower limit for eta')

    parser.add_argument('--eta-max',
                        type=float,
                        default=1.5,
                        required=False,
                        help='Upper limit for eta')

    parser.add_argument('--vlss-min',
                        type=float,
                        default=0.,
                        required=False,
                        help='Lower limit for variance LSS')

    parser.add_argument('--vlss-max',
                        type=float,
                        default=0.3,
                        required=False,
                        help='Upper limit for variance LSS')

    parser.add_argument('--delta-format',
                        type=str,
                        default=None,
                        required=False,
                        help='Format for Pk 1D: Pk1D')

    parser.add_argument('--use-ivar-as-weight',
                        action='store_true',
                        default=False,
                        help=('Use ivar as weights (implemented as eta = 1, '
                              'sigma_lss = fudge = 0)'))

    parser.add_argument('--use-constant-weight',
                        action='store_true',
                        default=False,
                        help=('Set all the delta weights to one (implemented '
                              'as eta = 0, sigma_lss = 1, fudge = 0)'))

    parser.add_argument('--order',
                        type=int,
                        default=1,
                        required=False,
                        help=('Order of the log10(lambda) polynomial for the '
                              'continuum fit, by default 1.'))

    parser.add_argument('--nit',
                        type=int,
                        default=5,
                        required=False,
                        help=('Number of iterations to determine the mean '
                              'continuum shape, LSS variances, etc.'))

    parser.add_argument('--nproc',
                        type=int,
                        default=None,
                        required=False,
                        help='Number of processors')

    parser.add_argument('--nspec',
                        type=int,
                        default=None,
                        required=False,
                        help='Maximum number of spectra to read')

    parser.add_argument('--use-mock-continuum',
                        action='store_true',
                        default=False,
                        help='use the mock continuum for computing the deltas')

    parser.add_argument('--spall',
                        type=str,
                        default=None,
                        required=False,
                        help=('Path to spAll file'))

    parser.add_argument('--metadata',
                        type=str,
                        default=None,
                        required=False,
                        help=('Name for table containing forests metadata'))

    t0 = time.time()

    args = parser.parse_args()

    # setup forest class variables
    Forest.log_lambda_min = np.log10(args.lambda_min)
    Forest.log_lambda_max = np.log10(args.lambda_max)
    Forest.log_lambda_min_rest_frame = np.log10(args.lambda_rest_min)
    Forest.log_lambda_max_rest_frame = np.log10(args.lambda_rest_max)
    Forest.rebin = args.rebin
    Forest.delta_log_lambda = args.rebin * 1e-4
    # minumum dla transmission
    Forest.dla_mask_limit = args.dla_mask
    Forest.absorber_mask_width = args.absorber_mask

    # Find the redshift range
    if args.zqso_min is None:
        args.zqso_min = max(0., args.lambda_min / args.lambda_rest_max - 1.)
        userprint("zqso_min = {}".format(args.zqso_min))
    if args.zqso_max is None:
        args.zqso_max = max(0., args.lambda_max / args.lambda_rest_min - 1.)
        userprint("zqso_max = {}".format(args.zqso_max))

    #-- Create interpolators for mean quantities, such as
    #-- Large-scale structure variance : var_lss
    #-- Pipeline ivar correction error: eta
    #-- Pipeline ivar correction term : fudge
    #-- Mean continuum : mean_cont
    log_lambda_temp = (Forest.log_lambda_min + np.arange(2) *
                       (Forest.log_lambda_max - Forest.log_lambda_min))
    log_lambda_rest_frame_temp = (
        Forest.log_lambda_min_rest_frame + np.arange(2) *
        (Forest.log_lambda_max_rest_frame - Forest.log_lambda_min_rest_frame))
    Forest.get_var_lss = interp1d(log_lambda_temp,
                                  0.2 + np.zeros(2),
                                  fill_value="extrapolate",
                                  kind="nearest")
    Forest.get_eta = interp1d(log_lambda_temp,
                              np.ones(2),
                              fill_value="extrapolate",
                              kind="nearest")
    Forest.get_fudge = interp1d(log_lambda_temp,
                                np.zeros(2),
                                fill_value="extrapolate",
                                kind="nearest")
    Forest.get_mean_cont = interp1d(log_lambda_rest_frame_temp,
                                    1 + np.zeros(2))

    #-- Check that the order of the continuum fit is 0 (constant) or 1 (linear).
    if args.order:
        if (args.order != 0) and (args.order != 1):
            userprint(("ERROR : invalid value for order, must be eqal to 0 or"
                       "1. Here order = {:d}").format(args.order))
            sys.exit(12)

    #-- Correct multiplicative pipeline flux calibration
    if args.flux_calib is not None:
        hdu = fitsio.read(args.flux_calib, ext=1)
        stack_log_lambda = hdu['loglam']
        stack_delta = hdu['stack']
        w = (stack_delta != 0.)
        Forest.correct_flux = interp1d(stack_log_lambda[w],
                                       stack_delta[w],
                                       fill_value="extrapolate",
                                       kind="nearest")

    #-- Correct multiplicative pipeline inverse variance calibration
    if args.ivar_calib is not None:
        hdu = fitsio.read(args.ivar_calib, ext=2)
        log_lambda = hdu['loglam']
        eta = hdu['eta']
        Forest.correct_ivar = interp1d(log_lambda,
                                       eta,
                                       fill_value="extrapolate",
                                       kind="nearest")

    ### Apply dust correction
    if not args.dust_map is None:
        userprint("applying dust correction")
        Forest.extinction_bv_map = io.read_dust_map(args.dust_map)

    log_file = open(os.path.expandvars(args.log), 'w')

    # Read data
    (data, num_data, nside,
     healpy_pix_ordering) = io.read_data(os.path.expandvars(args.in_dir),
                                         args.drq,
                                         args.mode,
                                         z_min=args.zqso_min,
                                         z_max=args.zqso_max,
                                         max_num_spec=args.nspec,
                                         log_file=log_file,
                                         keep_bal=args.keep_bal,
                                         bi_max=args.bi_max,
                                         best_obs=args.best_obs,
                                         single_exp=args.single_exp,
                                         pk1d=args.delta_format,
                                         spall=args.spall)

    #-- Add order info
    for pix in data:
        for forest in data[pix]:
            if not forest is None:
                forest.order = args.order

    ### Read masks
    if args.mask_file is not None:
        args.mask_file = os.path.expandvars(args.mask_file)
        try:
            mask = Table.read(args.mask_file,
                              names=('type', 'wave_min', 'wave_max', 'frame'),
                              format='ascii')
            mask['log_wave_min'] = np.log10(mask['wave_min'])
            mask['log_wave_max'] = np.log10(mask['wave_max'])
        except (OSError, ValueError):
            userprint(("ERROR: Error while reading mask_file "
                       "file {}").format(args.mask_file))
            sys.exit(1)
    else:
        mask = Table(names=('type', 'wave_min', 'wave_max', 'frame',
                            'log_wave_min', 'log_wave_max'))

    ### Mask lines
    for healpix in data:
        for forest in data[healpix]:
            forest.mask(mask)

    ### Mask absorbers
    if not args.absorber_vac is None:
        userprint("INFO: Adding absorbers")
        absorbers = io.read_absorbers(args.absorber_vac)
        num_absorbers = 0
        for healpix in data:
            for forest in data[healpix]:
                if forest.thingid in absorbers:
                    for lambda_absorber in absorbers[forest.thingid]:
                        forest.add_absorber(lambda_absorber)
                        num_absorbers += 1
        log_file.write("Found {} absorbers in forests\n".format(num_absorbers))

    ### Add optical depth contribution
    if not args.optical_depth is None:
        userprint(("INFO: Adding {} optical"
                   "depths").format(len(args.optical_depth) // 3))
        assert len(args.optical_depth) % 3 == 0
        for index in range(len(args.optical_depth) // 3):
            tau = float(args.optical_depth[3 * index])
            gamma = float(args.optical_depth[3 * index + 1])
            lambda_rest_frame = constants.ABSORBER_IGM[args.optical_depth[
                3 * index + 2]]
            userprint(
                ("INFO: Adding optical depth for tau = {}, gamma = {}, "
                 "lambda_rest_frame = {} A").format(tau, gamma,
                                                    lambda_rest_frame))
            for healpix in data:
                for forest in data[healpix]:
                    forest.add_optical_depth(tau, gamma, lambda_rest_frame)

    ### Mask DLAs
    if not args.dla_vac is None:
        userprint("INFO: Adding DLAs")
        np.random.seed(0)
        dlas = io.read_dlas(args.dla_vac)
        num_dlas = 0
        for healpix in data:
            for forest in data[healpix]:
                if forest.thingid in dlas:
                    for dla in dlas[forest.thingid]:
                        forest.add_dla(dla[0], dla[1], mask)
                        num_dlas += 1
        log_file.write("Found {} DLAs in forests\n".format(num_dlas))

    ## Apply cuts
    log_file.write(
        ("INFO: Input sample has {} "
         "forests\n").format(np.sum([len(forest)
                                     for forest in data.values()])))
    remove_keys = []
    for healpix in data:
        forests = []
        for forest in data[healpix]:
            if ((forest.log_lambda is None)
                    or len(forest.log_lambda) < args.npix_min):
                log_file.write(("INFO: Rejected {} due to forest too "
                                "short\n").format(forest.thingid))
                continue

            if np.isnan((forest.flux * forest.ivar).sum()):
                log_file.write(("INFO: Rejected {} due to nan "
                                "found\n").format(forest.thingid))
                continue

            if (args.use_constant_weight
                    and (forest.flux.mean() <= 0.0 or forest.mean_snr <= 1.0)):
                log_file.write(("INFO: Rejected {} due to negative mean or "
                                "too low SNR found\n").format(forest.thingid))
                continue

            forests.append(forest)
            log_file.write("{} {}-{}-{} accepted\n".format(
                forest.thingid, forest.plate, forest.mjd, forest.fiberid))
        data[healpix][:] = forests
        if len(data[healpix]) == 0:
            remove_keys += [healpix]

    for healpix in remove_keys:
        del data[healpix]

    num_forests = np.sum([len(forest) for forest in data.values()])
    log_file.write(("INFO: Remaining sample has {} "
                    "forests\n").format(num_forests))
    userprint(f"Remaining sample has {num_forests} forests")

    # Sanity check: all forests must have the attribute log_lambda
    for healpix in data:
        for forest in data[healpix]:
            assert forest.log_lambda is not None

    t1 = time.time()
    tmin = (t1 - t0) / 60
    userprint('INFO: time elapsed to read data', tmin, 'minutes')

    # compute fits to the forests iteratively
    # (see equations 2 to 4 in du Mas des Bourboux et al. 2020)
    num_iterations = args.nit
    for iteration in range(num_iterations):
        # A fresh fork-based pool is created every iteration; forked workers
        # inherit the Forest class-level interpolators (get_mean_cont,
        # get_eta, ...) as updated at the end of the previous iteration.
        context = multiprocessing.get_context('fork')
        pool = context.Pool(processes=args.nproc)
        userprint(
            f"Continuum fitting: starting iteration {iteration} of {num_iterations}"
        )

        #-- Sorting healpix pixels before giving to pool (for some reason)
        pixels = np.array([k for k in data])
        sort = pixels.argsort()
        sorted_data = [data[k] for k in pixels[sort]]
        data_fit_cont = pool.map(cont_fit, sorted_data)
        # pool.map preserves input order, so results map back to their
        # healpix keys by walking the same sorted key array.
        for index, healpix in enumerate(pixels[sort]):
            data[healpix] = data_fit_cont[index]

        userprint(
            f"Continuum fitting: ending iteration {iteration} of {num_iterations}"
        )

        pool.close()

        # Skip the global-quantity update after the final fit: the last
        # iteration's continua are used as-is.
        # NOTE(review): if args.nit == 1 this branch never runs, leaving
        # log_lambda, eta, var_lss, fudge, num_pixels (and the VAR arrays)
        # undefined for the fitsio writes further below — confirm nit >= 2
        # is enforced upstream.
        if iteration < num_iterations - 1:
            #-- Compute mean continuum (stack in rest-frame)
            (log_lambda_rest_frame, mean_cont,
             mean_cont_weight) = prep_del.compute_mean_cont(data)
            w = mean_cont_weight > 0.
            log_lambda_cont = log_lambda_rest_frame[w]
            # Multiply the current mean continuum by the stacked correction
            # and rebuild the class-level interpolator from it.
            new_cont = Forest.get_mean_cont(log_lambda_cont) * mean_cont[w]
            Forest.get_mean_cont = interp1d(log_lambda_cont,
                                            new_cont,
                                            fill_value="extrapolate")

            #-- Compute observer-frame mean quantities (var_lss, eta, fudge)
            if not (args.use_ivar_as_weight or args.use_constant_weight):
                (log_lambda, eta, var_lss, fudge, num_pixels, var_pipe_values,
                 var_delta, var2_delta, count, num_qso, chi2_in_bin, error_eta,
                 error_var_lss, error_fudge) = prep_del.compute_var_stats(
                     data, (args.eta_min, args.eta_max),
                     (args.vlss_min, args.vlss_max))
                # Only bins with at least one pixel enter the interpolators.
                w = num_pixels > 0
                Forest.get_eta = interp1d(log_lambda[w],
                                          eta[w],
                                          fill_value="extrapolate",
                                          kind="nearest")
                Forest.get_var_lss = interp1d(log_lambda[w],
                                              var_lss[w],
                                              fill_value="extrapolate",
                                              kind="nearest")
                Forest.get_fudge = interp1d(log_lambda[w],
                                            fudge[w],
                                            fill_value="extrapolate",
                                            kind="nearest")
            else:
                # Fixed-weight modes: no fit is performed; fill the same
                # outputs with trivial constants so the downstream FITS
                # writes still have data.
                num_bins = 10  # this value is arbitrary
                log_lambda = (
                    Forest.log_lambda_min + (np.arange(num_bins) + .5) *
                    (Forest.log_lambda_max - Forest.log_lambda_min) / num_bins)

                if args.use_ivar_as_weight:
                    # weight = ivar directly: eta=1, var_lss=0, fudge=0
                    userprint(("INFO: using ivar as weights, skipping eta, "
                               "var_lss, fudge fits"))
                    eta = np.ones(num_bins)
                    var_lss = np.zeros(num_bins)
                    fudge = np.zeros(num_bins)
                else:
                    # constant weights: eta=0, var_lss=1, fudge=0
                    userprint(("INFO: using constant weights, skipping eta, "
                               "var_lss, fudge fits"))
                    eta = np.zeros(num_bins)
                    var_lss = np.ones(num_bins)
                    fudge = np.zeros(num_bins)

                error_eta = np.zeros(num_bins)
                error_var_lss = np.zeros(num_bins)
                error_fudge = np.zeros(num_bins)
                chi2_in_bin = np.zeros(num_bins)

                num_pixels = np.zeros(num_bins)
                var_pipe_values = np.zeros(num_bins)
                var_delta = np.zeros((num_bins, num_bins))
                var2_delta = np.zeros((num_bins, num_bins))
                count = np.zeros((num_bins, num_bins))
                num_qso = np.zeros((num_bins, num_bins))

                Forest.get_eta = interp1d(log_lambda,
                                          eta,
                                          fill_value='extrapolate',
                                          kind='nearest')
                Forest.get_var_lss = interp1d(log_lambda,
                                              var_lss,
                                              fill_value='extrapolate',
                                              kind='nearest')
                Forest.get_fudge = interp1d(log_lambda,
                                            fudge,
                                            fill_value='extrapolate',
                                            kind='nearest')

    ### Read metadata from forests and export it
    if not args.metadata is None:
        tab_cont = get_metadata(data)
        tab_cont.write(args.metadata, format="fits", overwrite=True)

    # Observer-frame stack of all deltas; feeds the STACK HDU below and the
    # stack-delta interpolator used when computing the final delta fields.
    stack_log_lambda, stack_delta, stack_weight = prep_del.stack(data)

    ### Save iter_out_prefix
    results = fitsio.FITS(args.iter_out_prefix + ".fits.gz",
                          'rw',
                          clobber=True)
    header = {}
    header["NSIDE"] = nside
    header["PIXORDER"] = healpy_pix_ordering
    header["FITORDER"] = args.order
    results.write([stack_log_lambda, stack_delta, stack_weight],
                  names=['loglam', 'stack', 'weight'],
                  header=header,
                  extname='STACK')
    # NOTE(review): log_lambda, eta, var_lss, fudge, num_pixels and the VAR
    # arrays below are assigned only inside the iteration loop's
    # "not last iteration" branch; with args.nit == 1 these writes would
    # raise NameError — confirm nit >= 2 is enforced upstream.
    results.write([log_lambda, eta, var_lss, fudge, num_pixels],
                  names=['loglam', 'eta', 'var_lss', 'fudge', 'nb_pixels'],
                  extname='WEIGHT')
    results.write([
        log_lambda_rest_frame,
        Forest.get_mean_cont(log_lambda_rest_frame), mean_cont_weight
    ],
                  names=['loglam_rest', 'mean_cont', 'weight'],
                  extname='CONT')
    # Broadcast the 1D pipeline-variance bin centers to the 2D shape of
    # var_delta so all VAR columns share one shape.
    var_pipe_values = np.broadcast_to(var_pipe_values.reshape(1, -1),
                                      var_delta.shape)
    results.write(
        [var_pipe_values, var_delta, var2_delta, count, num_qso, chi2_in_bin],
        names=['var_pipe', 'var_del', 'var2_del', 'count', 'nqsos', 'chi2'],
        extname='VAR')
    results.close()

    ### Compute deltas and format them
    # Interpolator for the observer-frame stacked delta; only bins with
    # positive weight are used, with nearest-neighbour extrapolation.
    get_stack_delta = interp1d(stack_log_lambda[stack_weight > 0.],
                               stack_delta[stack_weight > 0.],
                               kind="nearest",
                               fill_value="extrapolate")
    deltas = {}
    data_bad_cont = []
    for healpix in sorted(data.keys()):
        for forest in data[healpix]:
            # Forests whose continuum fit failed are skipped here and
            # collected below for logging.
            if forest.bad_cont is not None:
                continue
            #-- Compute delta field from flux, continuum and various quantites
            get_delta_from_forest(forest, get_stack_delta, Forest.get_var_lss,
                                  Forest.get_eta, Forest.get_fudge,
                                  args.use_mock_continuum)
            deltas.setdefault(healpix, []).append(forest)
        data_bad_cont += [
            forest for forest in data[healpix] if forest.bad_cont is not None
        ]

    for forest in data_bad_cont:
        log_file.write("INFO: Rejected {} due to {}\n".format(
            forest.thingid, forest.bad_cont))

    # Bug fix: the adjacent literals previously concatenated to
    # "...has {}forests" (missing space in the log message).
    log_file.write(
        ("INFO: Accepted sample has {} "
         "forests\n").format(np.sum([len(p) for p in deltas.values()])))

    t2 = time.time()
    tmin = (t2 - t1) / 60
    userprint('INFO: time elapsed to fit continuum', tmin, 'minutes')

    ### Save delta
    # One output file per healpix pixel, either plain text (Pk1D_ascii) or
    # a multi-HDU FITS file (one HDU per forest).
    for healpix in sorted(deltas.keys()):

        if args.delta_format == 'Pk1D_ascii':
            results = open(args.out_dir + "/delta-{}".format(healpix) + ".txt",
                           'w')
            for delta in deltas[healpix]:
                num_pixels = len(delta.delta)
                # desi: derive the mean log-lambda spacing from the grid
                # endpoints; sdss: use the stored per-forest bin size.
                if args.mode == 'desi':
                    delta_log_lambda = (
                        (delta.log_lambda[-1] - delta.log_lambda[0]) /
                        float(len(delta.log_lambda) - 1))
                else:
                    delta_log_lambda = delta.delta_log_lambda
                # One line per forest: identifiers, coordinates, summary
                # stats, then the delta / log-lambda / ivar / diff arrays.
                line = '{} {} {} '.format(delta.plate, delta.mjd,
                                          delta.fiberid)
                line += '{} {} {} '.format(delta.ra, delta.dec, delta.z_qso)
                line += '{} {} {} {} {} '.format(delta.mean_z, delta.mean_snr,
                                                 delta.mean_reso,
                                                 delta_log_lambda, num_pixels)
                for index in range(num_pixels):
                    line += '{} '.format(delta.delta[index])
                for index in range(num_pixels):
                    line += '{} '.format(delta.log_lambda[index])
                for index in range(num_pixels):
                    line += '{} '.format(delta.ivar[index])
                for index in range(num_pixels):
                    line += '{} '.format(delta.exposures_diff[index])
                line += ' \n'
                results.write(line)

            results.close()

        else:
            results = fitsio.FITS(args.out_dir + "/delta-{}".format(healpix) +
                                  ".fits.gz",
                                  'rw',
                                  clobber=True)
            for delta in deltas[healpix]:
                # Per-forest header: identity, coordinates and fit metadata.
                header = [
                    {
                        'name': 'RA',
                        'value': delta.ra,
                        'comment': 'Right Ascension [rad]'
                    },
                    {
                        'name': 'DEC',
                        'value': delta.dec,
                        'comment': 'Declination [rad]'
                    },
                    {
                        'name': 'Z',
                        'value': delta.z_qso,
                        'comment': 'Redshift'
                    },
                    {
                        'name':
                        'PMF',
                        'value':
                        '{}-{}-{}'.format(delta.plate, delta.mjd,
                                          delta.fiberid)
                    },
                    {
                        'name': 'THING_ID',
                        'value': delta.thingid,
                        'comment': 'Object identification'
                    },
                    {
                        'name': 'PLATE',
                        'value': delta.plate
                    },
                    {
                        'name': 'MJD',
                        'value': delta.mjd,
                        'comment': 'Modified Julian date'
                    },
                    {
                        'name': 'FIBERID',
                        'value': delta.fiberid
                    },
                    {
                        'name': 'ORDER',
                        'value': delta.order,
                        'comment': 'Order of the continuum fit'
                    },
                ]

                # Pk1D format carries extra per-forest statistics and the
                # exposure-difference column instead of weights/continuum.
                if args.delta_format == 'Pk1D':
                    header += [
                        {
                            'name': 'MEANZ',
                            'value': delta.mean_z,
                            'comment': 'Mean redshift'
                        },
                        {
                            'name': 'MEANRESO',
                            'value': delta.mean_reso,
                            'comment': 'Mean resolution'
                        },
                        {
                            'name': 'MEANSNR',
                            'value': delta.mean_snr,
                            'comment': 'Mean SNR'
                        },
                    ]
                    # Same desi/sdss bin-size logic as the ascii branch.
                    if args.mode == 'desi':
                        delta_log_lambda = (
                            (delta.log_lambda[-1] - delta.log_lambda[0]) /
                            float(len(delta.log_lambda) - 1))
                    else:
                        delta_log_lambda = delta.delta_log_lambda
                    header += [{
                        'name': 'DLL',
                        'value': delta_log_lambda,
                        'comment': 'Loglam bin size [log Angstrom]'
                    }]
                    # Missing exposure differences are written as zeros so
                    # the DIFF column always matches the grid length.
                    exposures_diff = delta.exposures_diff
                    if exposures_diff is None:
                        exposures_diff = delta.log_lambda * 0

                    cols = [
                        delta.log_lambda, delta.delta, delta.ivar,
                        exposures_diff
                    ]
                    names = ['LOGLAM', 'DELTA', 'IVAR', 'DIFF']
                    units = ['log Angstrom', '', '', '']
                    comments = [
                        'Log lambda', 'Delta field', 'Inverse variance',
                        'Difference'
                    ]
                else:
                    cols = [
                        delta.log_lambda, delta.delta, delta.weights,
                        delta.cont
                    ]
                    names = ['LOGLAM', 'DELTA', 'WEIGHT', 'CONT']
                    units = ['log Angstrom', '', '', '']
                    comments = [
                        'Log lambda', 'Delta field', 'Pixel weights',
                        'Continuum'
                    ]

                # One HDU per forest, named by its THING_ID.
                results.write(cols,
                              names=names,
                              header=header,
                              comment=comments,
                              units=units,
                              extname=str(delta.thingid))

            results.close()

    # Final timing report (per-stage and total, in minutes) and log cleanup.
    t3 = time.time()
    tmin = (t3 - t2) / 60
    userprint('INFO: time elapsed to write deltas', tmin, 'minutes')
    ttot = (t3 - t0) / 60
    userprint('INFO: total elapsed time', ttot, 'minutes')

    log_file.close()